In [6]:
import os
import pickle
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import datasets, transforms

In [2]:
def mnist_set(norm=1, root=None):
    """Build the torchvision MNIST train/test datasets.

    Args:
        norm: 1 -> ToTensor followed by Normalize(mean=0.5, std=0.5), which maps
              pixel values from [0, 1] to [-1, 1];
              2 -> ToTensor only (it already divides uint8 pixels by 255).
        root: directory holding the dataset files. Defaults to
              ``~/project/data/dataset/mnist``.

    Returns:
        (train_set, test_set, input_shape, num_classes), where input_shape is
        (1, 28, 28) and num_classes is 10.

    Raises:
        NotImplementedError: if ``norm`` is neither 1 nor 2.
    """
    if norm == 1:
        trans = transforms.Compose([
            transforms.ToTensor(),
            # MNIST has a single channel. A 3-element mean/std cannot be
            # broadcast in place onto a (1, H, W) tensor and raises at load time.
            transforms.Normalize((0.5,), (0.5,))
        ])
    elif norm == 2:
        trans = transforms.Compose([
            # ToTensor already rescales uint8 pixels by 1/255.
            transforms.ToTensor()
        ])
    else:
        raise NotImplementedError("unsupported norm={!r}; expected 1 or 2".format(norm))

    if root is None:
        # expanduser works even when $HOME is unset (e.g. on Windows).
        root = os.path.join(os.path.expanduser('~'), 'project', 'data', 'dataset', 'mnist')
    train_set = datasets.MNIST(root, train=True, download=True, transform=trans)
    test_set = datasets.MNIST(root, train=False, transform=trans)
    return train_set, test_set, (1, 28, 28), 10

In [4]:

class InfiniteSampler(Sampler):
    """Sampler that yields dataset indices forever and reshuffles after each pass.

    The first pass walks the indices in their given order. That is ``index`` as
    supplied, or ``0..num_samples-1`` when ``index`` is None, which makes runs
    easy to compare. Every later pass uses a fresh ``np.random.permutation`` of
    the previous order, drawn from the global NumPy RNG.

    Args:
        num_samples: number of indices to draw from when ``index`` is None;
            ignored when ``index`` is given.
        index: optional array of dataset indices defining a subset; its length
            (``shape[0]``) becomes ``num_samples``.
    """

    def __init__(self, num_samples, index=None):
        self.order = index
        if self.order is not None:
            self.num_samples = self.order.shape[0]
        else:
            self.num_samples = num_samples

    def __iter__(self):
        # loop() is a generator, which is already an iterator.
        return self.loop()

    def __len__(self):
        # Effectively infinite; DataLoader only needs some integer here.
        return 2 ** 31

    def loop(self):
        """Yield indices endlessly, one full pass over ``self.order`` at a time.

        Raises:
            ValueError: if the sampler has no indices to draw from.
        """
        if self.num_samples <= 0:
            # Fail clearly. Otherwise the first self.order[0] raises an opaque
            # IndexError from inside DataLoader iteration.
            raise ValueError(
                "InfiniteSampler needs at least one index, got num_samples={}".format(self.num_samples))
        if self.order is None:
            # The first epoch uses 0..N-1, which makes runs easy to compare.
            self.order = np.arange(self.num_samples)
        while True:
            for position in range(self.num_samples):
                yield self.order[position]
            # Reshuffle for the next pass. permutation() returns a copy, so the
            # caller's `index` array is never modified.
            self.order = np.random.permutation(self.order)

In [12]:
# Full MNIST with pixels in [0, 1]; the test set is iterated in a fixed order.
train_set, test_set, shape, num_classes = mnist_set(norm=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
num_examples = len(train_set)

size = 1000  # total labeled training examples; size // 10 are taken per class

# The part that differs from pytorch_sup.py: train on a class-balanced subset
# containing size // 10 examples of each digit.
_train_set_y = train_set.targets.numpy()
per_class = size // 10
class_counts = np.bincount(_train_set_y, minlength=10)
if per_class > class_counts.min():
    # MNIST classes are not equally sized. Slicing a smaller class would
    # silently return fewer than `per_class` examples.
    raise ValueError(
        "size={} requests {} examples per class, but the smallest MNIST class "
        "has only {}".format(size, per_class, class_counts.min()))
ind = np.concatenate([np.where(_train_set_y == c)[0][:per_class] for c in range(10)])
# In-place shuffle (unseeded: seed np.random beforehand for a reproducible order).
np.random.shuffle(ind)
train_iter = iter(DataLoader(train_set, batch_size=128, num_workers=0,
                             sampler=InfiniteSampler(ind.shape[0], ind)))

In [14]:
import os
import sys

import pickle
import numpy as np


def load_npz_as_dict(path):
    """Load every array stored in a ``.npz`` archive into a plain dict.

    Args:
        path: path (or file-like object) of a ``.npz`` archive, as accepted by
            ``np.load``.

    Returns:
        dict mapping each array name in the archive to its ndarray. The arrays
        are read eagerly, so the archive can be closed before returning.
    """
    # NpzFile keeps the underlying zip file open until closed; the context
    # manager releases it even if reading an entry fails.
    with np.load(path) as data:
        return {key: data[key] for key in data.files}


def augmentation(images, random_crop=True, random_flip=True, pad_size=2):
    """Randomly translate and/or horizontally flip each image in a batch.

    Each image is reflect-padded by ``pad_size`` pixels on both spatial axes.
    An H x W window is then cut from the padded image, which shifts the content
    by between ``-pad_size`` and ``+pad_size`` pixels per axis. With
    ``random_crop=False`` the centre window is taken, which reproduces the
    unpadded image exactly.

    Args:
        images: 4-D array; axes 2 and 3 are treated as height and width,
            i.e. (N, C, H, W).
        random_crop: apply a random translation to each image.
        random_flip: with probability 1/2, flip each image along its last
            (width) axis.
        pad_size: maximum translation in pixels. 'reflect' padding requires it
            to be smaller than H and W.

    Returns:
        A new array with the same shape as ``images``.
    """
    h, w = images.shape[2], images.shape[3]
    if images.shape[0] == 0:
        # np.stack below needs at least one array.
        return images.copy()
    padded_images = np.pad(images, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)), 'reflect')
    aug_images = []
    for image in padded_images:
        if random_flip and np.random.uniform() > 0.5:
            image = image[:, :, ::-1]
        if random_crop:
            # randint's upper bound is exclusive. The +1 makes offset
            # 2 * pad_size (a +pad_size shift) reachable, so translations are
            # symmetric around the centre.
            offset_h = np.random.randint(0, 2 * pad_size + 1)
            offset_w = np.random.randint(0, 2 * pad_size + 1)
        else:
            offset_h = offset_w = pad_size
        aug_images.append(image[:, offset_h:offset_h + h, offset_w:offset_w + w])
    ret = np.stack(aug_images)
    assert ret.shape == images.shape
    return ret


def _augmentation(images, trans=True, flip=True):
    # shape of `image' [N, K, W, H]
    assert images.ndim == 4
    return augmentation(images, trans, flip)


class Data:
    """Paired (data, label) arrays with random mini-batch extraction.

    Adapted from vat_chainer's data-shuffling helper. The MNIST path in this
    notebook does not draw batches through it; it only reads ``self.data`` and
    ``self.label``.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label
        self.index = np.arange(self.N)

    @property
    def N(self):
        """Number of samples held (``len(self.data)``)."""
        return len(self.data)

    def get(self, n=None, shuffle=True, aug_trans=False, aug_flip=False):
        """Return up to ``n`` (data, label) rows, optionally shuffled and augmented.

        Args:
            n: number of rows wanted; None means all rows. Asking for more rows
                than exist returns every row.
            shuffle: draw rows in a random order (global NumPy RNG) instead of
                storage order.
            aug_trans: apply random translation via ``_augmentation``.
            aug_flip: apply random horizontal flips via ``_augmentation``.

        Returns:
            (batch_data, batch_label) tuple.
        """
        total = self.data.shape[0]
        order = np.random.permutation(total) if shuffle else np.arange(total)
        chosen = order[:total if n is None else n]
        batch_data, batch_label = self.data[chosen], self.label[chosen]
        if aug_trans or aug_flip:
            batch_data = _augmentation(batch_data, aug_trans, aug_flip)
        return batch_data, batch_label


def load_mnist_dataset(path='dataset/mnist.pkl'):
    """Load the pickled MNIST splits and merge the train and validation parts.

    Args:
        path: pickle file holding a 3-sequence ``(train, valid, test)`` of
            ``(images, labels)`` pairs. A relative path resolves against the
            current working directory.

    Returns:
        ``((train_x, train_y), (test_x, test_y))``. ``train`` is train+valid
        concatenated along axis 0, and ``train_x`` is cast to float32.

    Raises:
        FileNotFoundError: if ``path`` does not exist.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # `with` closes the file handle, which the bare open() call used to leak.
    with open(path, 'rb') as f:
        if sys.version_info.major == 3:
            # encoding="bytes" controls how Python-2-era str objects in the
            # pickle are decoded under Python 3.
            dataset = pickle.load(f, encoding="bytes")
        else:
            dataset = pickle.load(f)
    train_set_x = np.concatenate((dataset[0][0], dataset[1][0]), axis=0).astype("float32")
    train_set_y = np.concatenate((dataset[0][1], dataset[1][1]), axis=0)
    return (train_set_x, train_set_y), (dataset[2][0], dataset[2][1])


def load_mnist_for_semi_sup(n_l=1000, n_v=1000):
    """Split MNIST into disjoint labeled, unlabeled, validation and test sets.

    The merged train+valid pool from ``load_mnist_dataset`` is shuffled with
    the global NumPy RNG; seed it beforehand for reproducible splits.
    - Labeled set: the first ``n_l // 10`` examples of each digit 0-9 in the
      shuffled pool.
    - Validation set: the first ``n_v`` remaining examples.
    - Unlabeled set: everything else.
    The unlabeled set never contains labeled examples.

    Args:
        n_l: requested number of labeled examples. ``n_l // 10`` are taken per
            class, so exactly ``10 * (n_l // 10)`` are returned.
        n_v: number of validation examples.

    Returns:
        ``((train_x, train_y), (train_ul_x, train_ul_y), (valid_x, valid_y),
        (test_x, test_y))``. Every ``*_x`` is reshaped to (n, 1, 28, 28);
        ``train_x`` is float32 and ``train_y`` is float64 (callers cast it).

    Raises:
        ValueError: if some class has fewer than ``n_l // 10`` examples.
    """
    dataset = load_mnist_dataset()

    _train_set_x, _train_set_y = dataset[0]
    test_set_x, test_set_y = dataset[1]
    test_set_x = test_set_x.reshape(test_set_x.shape[0], 1, 28, 28)

    rand_ind = np.random.permutation(_train_set_x.shape[0])  # controlled by the caller's seed
    _train_set_x = _train_set_x[rand_ind]
    _train_set_y = _train_set_y[rand_ind]

    s_c = int(n_l / 10.0)
    # Take the first s_c occurrences of each class in the shuffled pool, in
    # class order. Sizing from the picked rows (not n_l) avoids all-zero
    # "images" labelled 0 when n_l is not a multiple of 10.
    per_class = []
    for c in range(10):
        ind = np.where(_train_set_y == c)[0][:s_c]
        if ind.shape[0] < s_c:
            raise ValueError(
                "class {} has only {} examples, but {} were requested".format(c, ind.shape[0], s_c))
        per_class.append(ind)
    picked = np.concatenate(per_class)
    train_set_x = _train_set_x[picked]
    # Labels stay float64, the dtype this function has always returned.
    train_set_y = _train_set_y[picked].astype(np.float64)

    # Remove the labeled examples from the pool while keeping its shuffled order.
    keep = np.ones(_train_set_x.shape[0], dtype=bool)
    keep[picked] = False
    _train_set_x = _train_set_x[keep]
    _train_set_y = _train_set_y[keep]

    # Shuffle the labeled set out of its class-sorted order.
    l_rand_ind = np.random.permutation(train_set_x.shape[0])
    train_set_x = train_set_x[l_rand_ind].astype("float32").reshape(l_rand_ind.shape[0], 1, 28, 28)
    train_set_y = train_set_y[l_rand_ind]

    valid_set_x = _train_set_x[:n_v].reshape(n_v, 1, 28, 28)
    valid_set_y = _train_set_y[:n_v]
    train_set_ul_x = _train_set_x[n_v:]
    train_set_ul_x = train_set_ul_x.reshape(train_set_ul_x.shape[0], 1, 28, 28)
    train_set_ul_y = _train_set_y[n_v:]

    return (train_set_x, train_set_y), (train_set_ul_x, train_set_ul_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)


def load_dataset(dirpath, size=0):
    """Load a semi-supervised split and wrap each part in ``Data``.

    Only MNIST is supported. It is selected when ``'mnist'`` appears in
    ``dirpath``; the path is otherwise unused.

    Args:
        dirpath: dataset identifier/path; must contain ``'mnist'``.
        size: number of labeled examples, forwarded as ``n_l``.

    Returns:
        ``(labeled, unlabeled, test)`` ``Data`` objects whose labels are cast to
        int32. The validation split is discarded.

    Raises:
        NotImplementedError: for any non-MNIST ``dirpath``.
    """
    if 'mnist' not in dirpath:
        raise NotImplementedError
    labeled, unlabeled, _valid, test = load_mnist_for_semi_sup(n_l=size)

    def _as_data(split):
        # `split` is an (images, labels) pair.
        return Data(split[0], split[1].astype(np.int32))

    return _as_data(labeled), _as_data(unlabeled), _as_data(test)

In [15]:
# Draw one labeled mini-batch from the 100-example labeled split and convert
# it to tensors. Needs `dataset/mnist.pkl` relative to the working directory
# (read by load_mnist_dataset). NOTE(review): this rebinds `test_set`, which
# earlier held the torchvision MNIST test split.
train_l, train_ul, test_set = load_dataset("mnist", size=100)
x, t = train_l.get(128)  # only 100 labeled examples exist, so all 100 come back
images = torch.as_tensor(x, dtype=torch.float32)
labels = torch.as_tensor(t, dtype=torch.long)

FileNotFoundError: [Errno 2] No such file or directory: 'dataset/mnist.pkl'