Creates fixed training subsets from MNIST so that all models (FF, M1 and M2) are trained on the same images.

In [1]:
import math 
import torch
import numpy as np
from torch import nn, Tensor
from torch.nn.functional import softplus
from torch.distributions import Distribution
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision import transforms, utils
from functools import reduce

def flatten(x):
    """Convert an image with ToTensor and reshape it into a 784-element vector."""
    return ToTensor()(x).view(28**2)

from torch.distributions import Bernoulli

def binarization(x):
    """Draw a binary tensor, using each entry of `x` as its Bernoulli probability."""
    return torch.bernoulli(x)

# Transform used by the datasets: flatten first, then binarize stochastically
c_transform = transforms.Compose([flatten, binarization])

# MNIST train and test splits, both stored under "./" and both using c_transform;
# only the train split requests a download.
dset_train = MNIST(root="./", train=True, transform=c_transform, download=True)
dset_test = MNIST(root="./", train=False, transform=c_transform)

# Digit classes that are allowed into the samplers
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

def stratified_sampler(labels):
    """Return a SubsetRandomSampler restricted to samples whose label is in `classes`.

    `labels` is a tensor of per-sample labels (it is converted with `.numpy()`).
    The positions whose label appears in `classes` become the sampler's indices.
    """
    keep_mask = np.isin(labels.numpy(), classes)
    (keep_positions,) = np.where(keep_mask)
    return SubsetRandomSampler(torch.from_numpy(keep_positions))

# Mini-batch sizes for training and for evaluation
batch_size = 64
eval_batch_size = 64

# The loaders perform the actual work: each one draws its batches through a
# sampler built by stratified_sampler from the dataset's labels.
# `.targets` holds those labels; `train_labels` / `test_labels` are deprecated
# aliases for it that emit a warning on access.
train_loader = DataLoader(dset_train, batch_size=batch_size,
                          sampler=stratified_sampler(dset_train.targets))
test_loader  = DataLoader(dset_test, batch_size=eval_batch_size,
                          sampler=stratified_sampler(dset_test.targets))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!








In [2]:
# Mount Google Drive at /content/drive (requires the Google Colab runtime).
# The cells below save their tensors under the relative path
# "drive//My Drive/Vae/", so presumably the working directory is /content
# -- TODO confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Build the small labelled training set: 10 binarized images for each of the
# 10 digit classes (100 in total), then save images and labels to Drive.
total = np.zeros(10)               # images kept so far, per digit class
r_images = torch.ones((100, 784))  # one flattened 28x28 image per row
r_labels = torch.ones((100))       # digit label for the matching row of r_images
j = 0                              # next free row in r_images / r_labels

while total.sum() < 100: #Random which items remembered each time
  # Take one shuffled pass over the loader instead of calling
  # next(iter(train_loader)) repeatedly: each such call starts a freshly
  # shuffled iterator, so the same image could be drawn and stored twice.
  # The outer loop only repeats if a whole pass could not fill every class.
  for images, labels in train_loader:
    # zip visits every sample of the batch (the old range(63) skipped the last
    # one) and also copes with a final batch smaller than batch_size.
    for image, label in zip(images, labels):
      digit = label.item()
      if total[digit] < 10:
        total[digit] += 1
        r_images[j] = image
        r_labels[j] = label
        j += 1
    if total.sum() >= 100:
      break

# Check:
#torch.histc(r_labels, bins=10, min=0, max=9)
#r_labels

path = "drive//My Drive/Vae/r_images.pt"
torch.save(r_images, path)
path = "drive//My Drive/Vae/r_labels.pt"
torch.save(r_labels, path)

Unlabelled training data for the semi-supervised model (M2):

In [35]:
# Build a class-balanced set of 100 images (10 per digit) whose stored label is
# always 10 instead of the true digit -- presumably the "unlabelled" marker
# expected by M2; confirm against the code that loads these files.
total = np.zeros(10)                    # images kept so far, per digit class
r_images_E100 = torch.ones((100, 784))  # one flattened 28x28 image per row
r_labels_E100 = torch.ones((100))       # every filled entry is set to 10 below
j = 0                                   # next free row in the output tensors

while total.sum() < 100: #Random which items remembered each time
  # Take one shuffled pass over the loader instead of calling
  # next(iter(train_loader)) repeatedly: each such call starts a freshly
  # shuffled iterator, so the same image could be drawn and stored twice.
  # The outer loop only repeats if a whole pass could not fill every class.
  for images, labels in train_loader:
    # zip visits every sample of the batch (the old range(63) skipped the last
    # one) and also copes with a final batch smaller than batch_size.
    for image, label in zip(images, labels):
      digit = label.item()  # true digit, used only to keep the classes balanced
      if total[digit] < 10:
        total[digit] += 1
        r_images_E100[j] = image
        r_labels_E100[j] = 10
        j += 1
    if total.sum() >= 100:
      break

path = "drive//My Drive/Vae/r_images_E100.pt"
torch.save(r_images_E100, path)
path = "drive//My Drive/Vae/r_labels_E100.pt"
torch.save(r_labels_E100, path)

In [36]:
# Build a class-balanced set of 1000 images (100 per digit) whose stored label
# is always 10 instead of the true digit -- presumably the "unlabelled" marker
# expected by M2; confirm against the code that loads these files.
total = np.zeros(10)                      # images kept so far, per digit class
r_images_E1000 = torch.ones((1000, 784))  # one flattened 28x28 image per row
r_labels_E1000 = torch.ones((1000))       # every filled entry is set to 10 below
j = 0                                     # next free row in the output tensors

while total.sum() < 1000: #Random which items remembered each time
  # Take one shuffled pass over the loader instead of calling
  # next(iter(train_loader)) repeatedly: each such call starts a freshly
  # shuffled iterator, so the same image could be drawn and stored twice.
  # The outer loop only repeats if a whole pass could not fill every class.
  for images, labels in train_loader:
    # zip visits every sample of the batch (the old range(63) skipped the last
    # one) and also copes with a final batch smaller than batch_size.
    for image, label in zip(images, labels):
      digit = label.item()  # true digit, used only to keep the classes balanced
      if total[digit] < 100:
        total[digit] += 1
        r_images_E1000[j] = image
        r_labels_E1000[j] = 10
        j += 1
    if total.sum() >= 1000:
      break

path = "drive//My Drive/Vae/r_images_E1000.pt"
torch.save(r_images_E1000, path)
path = "drive//My Drive/Vae/r_labels_E1000.pt"
torch.save(r_labels_E1000, path)

In [37]:
# Build a class-balanced set of 10000 images (1000 per digit) whose stored label
# is always 10 instead of the true digit -- presumably the "unlabelled" marker
# expected by M2; confirm against the code that loads these files.
total = np.zeros(10)                        # images kept so far, per digit class
r_images_E10000 = torch.ones((10000, 784))  # one flattened 28x28 image per row
r_labels_E10000 = torch.ones((10000))       # every filled entry is set to 10 below
j = 0                                       # next free row in the output tensors

while total.sum() < 10000: #Random which items remembered each time
  # Take one shuffled pass over the loader instead of calling
  # next(iter(train_loader)) repeatedly: each such call starts a freshly
  # shuffled iterator, so the same image could be drawn and stored twice
  # (very likely when this many samples are collected).
  # The outer loop only repeats if a whole pass could not fill every class.
  for images, labels in train_loader:
    # zip visits every sample of the batch (the old range(63) skipped the last
    # one) and also copes with a final batch smaller than batch_size.
    for image, label in zip(images, labels):
      digit = label.item()  # true digit, used only to keep the classes balanced
      if total[digit] < 1000:
        total[digit] += 1
        r_images_E10000[j] = image
        r_labels_E10000[j] = 10
        j += 1
    if total.sum() >= 10000:
      break

path = "drive//My Drive/Vae/r_images_E10000.pt"
torch.save(r_images_E10000, path)
path = "drive//My Drive/Vae/r_labels_E10000.pt"
torch.save(r_labels_E10000, path)