# “Unsupervised” learning on MNIST data

## Overview

Running this notebook reproduces an experiment in which a model is trained to classify MNIST images from a very small number of labeled examples (10 images per digit class).

## Loading MNIST dataset

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from torchvision import datasets

num_workers = 0
batch_size = 20
# Fall back to the CPU when no CUDA device is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.ToTensor()

train_data = datasets.MNIST(root='data', train=True,  download=True, transform=transform)
test_data  = datasets.MNIST(root='data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_data,  batch_size=batch_size, num_workers=num_workers, shuffle=True)

## Label vectors

Instead of the labels provided in the original MNIST dataset, we use 10 normalized label vectors of 25 positive reals each. The label vectors must meet a simple criterion: every pair of vectors has a cosine similarity of at most 0.5. This keeps the label vectors well separated from one another. In other words, the label vectors act as 10 cluster centroids whose pairwise cosine distance (1 minus the cosine similarity) is at least 0.5.

In [None]:
vector_labels = torch.tensor(
  [
    [0.0788, 0.1254, 0.0268, 0.2183, 0.3003, 0.3279, 0.0510, 0.0439, 0.3552,
      0.0749, 0.0862, 0.3351, 0.0506, 0.1168, 0.1643, 0.1960, 0.2681, 0.2267,
      0.1867, 0.3841, 0.0528, 0.1355, 0.0233, 0.0842, 0.2447], # digit 0
    [0.0751, 0.1258, 0.2656, 0.0132, 0.0048, 0.0459, 0.0123, 0.0921, 0.0722,
      0.4720, 0.0563, 0.0858, 0.1946, 0.0289, 0.0549, 0.0889, 0.0180, 0.0328,
      0.2810, 0.0050, 0.2128, 0.0508, 0.6535, 0.0418, 0.2278],  # digit 1
    [0.2016, 0.2286, 0.0246, 0.0350, 0.0284, 0.0113, 0.5795, 0.0695, 0.0989,
      0.0515, 0.5213, 0.0253, 0.2721, 0.0178, 0.0095, 0.0061, 0.0243, 0.0309,
      0.1076, 0.0285, 0.2499, 0.1712, 0.0729, 0.3033, 0.0366],  # digit 2
    [0.0282, 0.1384, 0.0564, 0.0931, 0.0529, 0.0337, 0.1164, 0.1348, 0.0838,
      0.0159, 0.3105, 0.0114, 0.1378, 0.0378, 0.0598, 0.7858, 0.0471, 0.1663,
      0.0565, 0.0347, 0.0680, 0.2686, 0.0610, 0.2588, 0.0746],  # digit 3
    [0.0368, 0.1348, 0.6532, 0.1854, 0.0323, 0.1109, 0.1187, 0.0222, 0.0124,
      0.1473, 0.2992, 0.0402, 0.0864, 0.2666, 0.0860, 0.0477, 0.2516, 0.2643,
      0.0793, 0.1082, 0.3297, 0.0828, 0.0344, 0.0016, 0.1499],  # digit 4
    [0.4528, 0.1523, 0.0251, 0.0511, 0.1991, 0.0561, 0.1557, 0.6983, 0.0731,
      0.1860, 0.0884, 0.0276, 0.1971, 0.1252, 0.1858, 0.0104, 0.0453, 0.1107,
      0.0252, 0.0585, 0.0470, 0.1628, 0.1336, 0.0263, 0.1031],  # digit 5
    [0.1442, 0.0414, 0.0260, 0.1051, 0.1011, 0.0484, 0.0940, 0.0715, 0.0459,
      0.5043, 0.0312, 0.0336, 0.1456, 0.5837, 0.2772, 0.0289, 0.4086, 0.0950,
      0.1037, 0.0274, 0.0142, 0.0024, 0.0353, 0.2247, 0.0419],  # digit 6
    [0.0253, 0.0100, 0.0521, 0.1290, 0.1214, 0.1809, 0.0092, 0.0036, 0.0223,
      0.0029, 0.0192, 0.0209, 0.5189, 0.1902, 0.2108, 0.0400, 0.0302, 0.1245,
      0.4193, 0.0728, 0.0449, 0.5350, 0.2411, 0.1886, 0.0111],  # digit 7
    [0.1489, 0.7321, 0.0011, 0.0065, 0.0215, 0.0462, 0.1617, 0.0958, 0.1073,
      0.0546, 0.0956, 0.0459, 0.0222, 0.1168, 0.2183, 0.0008, 0.0317, 0.0330,
      0.0749, 0.3990, 0.0605, 0.1267, 0.3394, 0.1195, 0.0142],  # digit 8
    [0.5247, 0.0227, 0.2495, 0.0037, 0.1025, 0.2442, 0.0776, 0.0051, 0.0561,
      0.0639, 0.0168, 0.1553, 0.1897, 0.0241, 0.0676, 0.1564, 0.0050, 0.0364,
      0.0179, 0.0769, 0.0420, 0.0796, 0.2302, 0.5864, 0.2758]  # digit 9
  ]
)
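
The separation criterion described above can be checked in a few lines. The sketch below is only an illustration: the helper name `max_pairwise_cosine` and the toy 3-dimensional vectors are assumptions, not part of the original experiment.

```python
import torch
import torch.nn.functional as F

def max_pairwise_cosine(vectors):
    """Largest cosine similarity between two different rows of `vectors`."""
    unit = F.normalize(vectors, dim=1)       # scale every row to unit length
    sim = unit @ unit.t()                    # all pairwise cosine similarities
    sim.fill_diagonal_(float('-inf'))        # ignore each vector's similarity with itself
    return sim.max().item()

# Toy vectors (made up for illustration), one dominant coordinate each
toy = torch.tensor([[1.0, 0.1, 0.1],
                    [0.1, 1.0, 0.1],
                    [0.1, 0.1, 1.0]])

worst = max_pairwise_cosine(toy)
print(worst <= 0.5)  # True for these toy vectors
```

Applied to the notebook's tensor, the same check would read `max_pairwise_cosine(vector_labels) <= 0.5`.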

## The model

We use a simple model with 3 convolutional layers, each followed by a ReLU activation and a max-pooling layer.

For a given 28 x 28 pixel image of a handwritten digit, the model generates a vector of 25 non-negative reals.
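
The output size of 25 follows from the layer arithmetic. The sketch below assumes the settings used in this notebook's model definition (3 x 3 kernels with paddings 1, 2 and 2, stride 1, and 2 x 2 max-pooling with stride 2); the helper names `conv_out` and `pool_out` are illustrative.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Spatial output size of a max-pooling layer along one dimension."""
    return (size - kernel) // stride + 1

size = 28
for padding in (1, 2, 2):  # paddings of conv1, conv2, conv3
    size = pool_out(conv_out(size, kernel=3, padding=padding))
    print(size)            # 14, then 8, then 5

print(size * size)         # 25 values, since the last layer has a single channel
```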

The loss function measures the cosine distance (1 minus the cosine similarity) between the generated vector and the label vector associated with the class of a given image, averaged over the batch.
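
Written out, the loss could be expressed as follows. The symbols are notation introduced here for illustration: $B$ is the batch size, $x_i$ is the generated vector for image $i$, and $y_i$ is its label vector. The small `eps` guard on the denominator is ignored.

```latex
\mathcal{L}(X, Y) = \frac{1}{B} \sum_{i=1}^{B}
  \left( 1 - \frac{x_i \cdot y_i}{\lVert x_i \rVert \, \lVert y_i \rVert} \right)
```

Because both the generated vectors and the label vectors have no negative components, every term lies between 0 (same direction) and 1 (orthogonal).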

In [None]:
def loss_fn(x, y):
  return (1 - nn.CosineSimilarity(eps=1e-6)(x, y)).sum() / x.size(0)

class MNIST_Classifier(nn.Module):
    def __init__(self):
      super(MNIST_Classifier, self).__init__()

      self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
      self.conv2 = nn.Conv2d(16, 8, 3, padding=2)
      self.conv3 = nn.Conv2d(8,  1, 3, padding=2)

      self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = self.pool(F.relu(self.conv3(x)))

      return x

    def init_weights(self):
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        torch.nn.init.xavier_uniform_(self.conv3.weight)

## Initial training set

We randomly choose 10 images of each digit class from the MNIST training set and store them in a separate training set.
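
The per-class sampling idea can be illustrated without any tensors. This is a minimal sketch on a toy label list; `sample_per_class` is a hypothetical helper, not the function used in this notebook.

```python
import random
from collections import defaultdict

def sample_per_class(labels, per_class, seed=0):
    """Return {class: sorted indices}, with up to `per_class` random indices per class."""
    rng = random.Random(seed)                 # local generator, reproducible for a given seed
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)           # group example indices by class
    return {
        label: sorted(rng.sample(idxs, min(per_class, len(idxs))))
        for label, idxs in sorted(by_class.items())
    }

# Toy labels (made up): 200 examples, 20 per digit class
labels = [i % 10 for i in range(200)]
chosen = sample_per_class(labels, per_class=10)
print({label: len(idxs) for label, idxs in chosen.items()})
```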

In [None]:
def sort_train_loader(train_loader, sample_size = 10, seed = 0):
  # The torchvision MNIST dataset exposes .data/.targets; a TensorDataset exposes .tensors
  original_train_loader = hasattr(train_loader.dataset, 'targets')

  if original_train_loader:
    train_targets = train_loader.dataset.targets
  else:
    train_targets = train_loader.dataset.tensors[1]

  train_loader_digits = []

  torch.manual_seed(seed)

  for i in range(vector_labels.size(0)):
    if original_train_loader:
      ds = train_loader.dataset.data[train_targets == i].unsqueeze(1).float() / 255
    else:
      ds = train_loader.dataset.tensors[0][train_targets == i].unsqueeze(1).float() / 255

    train_loader_digits.append(
      torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(
          ds[torch.randperm(ds.size(0))][0:sample_size],
          train_targets[train_targets == i][0:sample_size]
        ),
        batch_size  = batch_size,
        num_workers = num_workers,
        shuffle     = True
      )
    )

  return train_loader_digits

train_loader_digits = sort_train_loader(train_loader, sample_size = 10)

## Training initial model

We use 10 separate models, one per digit class. Each model is trained on its own training set of only 10 images. As mentioned above, each model is associated with the label vector that represents its class. Each model is retrained from a fresh random initialization until its loss drops below `min_loss` or `max_tries` is reached, and the model with the lowest loss is kept.

In [None]:
import copy

def train_model(model, train_loader_digit, opt, n_epochs=20):
  for epoch in range(1, n_epochs + 1):
    train_loss = 0.0

    for data in train_loader_digit:
      images, labels = data

      opt.zero_grad()

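      out = model(images.to(device)).view(images.size(0), 25)

      # Look up the target vector of every image in the batch by its class index
      targets = vector_labels[labels].to(device)

      # loss_fn returns the batch mean; multiplying by the batch size turns it into a sum
      loss = loss_fn(out, targets) * images.size(0)

      loss.backward()

      opt.step()

      # The summed loss is scaled by the batch size once more here, so with equally
      # sized batches the returned value is the mean cosine distance times the batch size
      train_loss += (loss.item() * images.size(0))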

    train_loss = train_loss / train_loader_digit.dataset.tensors[0].size(0)

  return train_loss

def choose_best_model(digit=0, epochs=50, min_loss=0.35, max_tries=100):
  try_num = 0
  best_model = None
  best_loss = None

  while try_num < max_tries:
    model  = MNIST_Classifier().to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=0.01)

    model.init_weights()

    model.train()

    loss = train_model(model, train_loader_digits[digit], opt, epochs)

    if loss < min_loss:
      best_model = copy.deepcopy(model)

      best_loss = loss

      break

    if best_loss is None or loss < best_loss:
      best_model = copy.deepcopy(model)

      best_loss = loss

    try_num += 1

  return best_model, best_loss

def train_digit_models(epochs = 50, max_tries = 50, min_loss = 0.35, seed = 0):
  torch.manual_seed(seed)

  digit_models = []

  for digit in range(10):
    print('Training the model for digit {} (training set size {})'.format(digit, train_loader_digits[digit].dataset.tensors[0].size(0)))

    model, loss = choose_best_model(digit, epochs=epochs, min_loss=min_loss, max_tries=max_tries)

    print('Best loss for {} digit: {:.6f}'.format(digit, loss))

    digit_models.append(copy.deepcopy(model))

  return digit_models

vector_labels = vector_labels[torch.randperm(vector_labels.size(0))]

digit_models = train_digit_models(epochs = 50, max_tries = 50, min_loss = 0.35)

Training the model for digit 0 (training set size 10)
Best loss for 0 digit: 0.164769
Training the model for digit 1 (training set size 10)
Best loss for 1 digit: 0.301143
Training the model for digit 2 (training set size 10)
Best loss for 2 digit: 0.152520
Training the model for digit 3 (training set size 10)
Best loss for 3 digit: 0.142487
Training the model for digit 4 (training set size 10)
Best loss for 4 digit: 0.170634
Training the model for digit 5 (training set size 10)
Best loss for 5 digit: 0.156419
Training the model for digit 6 (training set size 10)
Best loss for 6 digit: 0.647418
Training the model for digit 7 (training set size 10)
Best loss for 7 digit: 0.243345
Training the model for digit 8 (training set size 10)
Best loss for 8 digit: 0.284011
Training the model for digit 9 (training set size 10)
Best loss for 9 digit: 0.087508


## Classification of digit images

Our initial models have very limited knowledge of handwritten digits. In the next step we apply them to the whole training set (60K images) and label the images for which a model returns a cosine similarity above the threshold value of 0.895.

This gives us new labeled examples that can be used to train models with much better performance.
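
The selection rule can be sketched on a toy similarity matrix. The helper name `select_confident` and the numbers are made up for illustration; rows stand for images and columns for digit models.

```python
import torch

def select_confident(similarity, threshold=0.895):
    """Return indices of rows whose best similarity exceeds `threshold`, and their labels."""
    best_sim, best_class = similarity.max(dim=1)                # best model per image
    keep = (best_sim > threshold).nonzero(as_tuple=True)[0]    # confident rows only
    return keep, best_class[keep]

# Toy similarities for 3 images and 3 classes
sims = torch.tensor([[0.95, 0.20, 0.10],
                     [0.50, 0.60, 0.55],
                     [0.10, 0.30, 0.91]])

keep, pseudo_labels = select_confident(sims)
print(keep.tolist(), pseudo_labels.tolist())  # [0, 2] [0, 2]
```

The second image is dropped because none of its similarities exceeds the threshold.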

In [None]:
def classify_images(images, labels):
  # Returns one list of cosine similarities per digit model
  loss = []

  for j in range(len(digit_models)):
    out = digit_models[j](images.to(device)).view(images.size(0), 25)

    # Each output is compared with the label vector of the image's dataset label
    loss.append(nn.CosineSimilarity(eps=1e-6)(out, vector_labels[labels].to(device)).tolist())

  return loss

classified_cnt = 0

min_similarity = 0.895

i = 0

selected_images = []
selected_labels = []

for data in train_loader:
  images, labels = data

  loss = classify_images(images, labels)

  # Rows are images, columns are digit models
  loss_tr = torch.tensor(loss).t()

  # Keep the images for which at least one model exceeds the similarity threshold
  selected_ind = torch.arange(images.size(0))[(loss_tr > min_similarity).any(1)]

  # The new label of a selected image is the index of the most similar model
  new_labels = loss_tr[selected_ind,:].argsort(1)[:,-1]

  selected_images.append(images[selected_ind])
  selected_labels.append(new_labels)

  # Print the batch if any new label disagrees with the original MNIST label
  if (new_labels != labels[selected_ind]).any():
    print(new_labels)
    print(labels[selected_ind])
  else:
    classified_cnt += selected_ind.size(0)

    if i % 100 == 0:
      print('Classified {} images'.format(classified_cnt))

  i += 1

print('Total classified images {}'.format(classified_cnt))

selected_images = torch.cat(selected_images) * 255
selected_labels = torch.cat(selected_labels)

new_train_loader = torch.utils.data.DataLoader(
  torch.utils.data.TensorDataset(
    selected_images.squeeze(1),
    selected_labels
  ),
  batch_size  = batch_size,
  num_workers = num_workers,
  shuffle     = True
)

selected_labels.unique(return_counts=True)

Classified 15 images
Classified 1396 images
Classified 2757 images
Classified 4105 images
Classified 5486 images
Classified 6852 images
Classified 8211 images
Classified 9544 images
Classified 10880 images
Classified 12258 images
Classified 13633 images
Classified 15018 images
Classified 16378 images
Classified 17733 images
Classified 19095 images
Classified 20452 images
Classified 21797 images
Classified 23182 images
Classified 24528 images
Classified 25891 images
Classified 27223 images
Classified 28606 images
Classified 29961 images
Classified 31329 images
Classified 32685 images
Classified 34061 images
Classified 35460 images
Classified 36794 images
Classified 38162 images
Classified 39532 images
Total classified images 40871


(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 tensor([5300, 6727, 2535, 4623, 2910, 4949,  770, 6259, 3092, 3706]))

## Training the model for classification

We now have a much bigger labeled data set, so we can use more example images and train more accurate models.

Here we create a new training set with 500 randomly selected images per digit class and train the models from scratch on it.

In [None]:
train_loader_digits = sort_train_loader(new_train_loader, sample_size = 500, seed = 1)

vector_labels = vector_labels[torch.randperm(vector_labels.size(0))]

digit_models = train_digit_models(epochs = 20, max_tries = 10, min_loss = 0.55, seed = 1)

Training the model for digit 0 (training set size 500)
Best loss for 0 digit: 0.153815
Training the model for digit 1 (training set size 500)
Best loss for 1 digit: 0.466718
Training the model for digit 2 (training set size 500)
Best loss for 2 digit: 0.573049
Training the model for digit 3 (training set size 500)
Best loss for 3 digit: 0.322496
Training the model for digit 4 (training set size 500)
Best loss for 4 digit: 0.485401
Training the model for digit 5 (training set size 500)
Best loss for 5 digit: 0.814577
Training the model for digit 6 (training set size 500)
Best loss for 6 digit: 0.486279
Training the model for digit 7 (training set size 500)
Best loss for 7 digit: 0.188087
Training the model for digit 8 (training set size 500)
Best loss for 8 digit: 0.079691
Training the model for digit 9 (training set size 500)
Best loss for 9 digit: 0.444722


## Testing the model

In the eval_model() function we apply the 10 models to a given image and calculate the cosine similarity between each generated vector and a label vector, which the code looks up through the image's dataset label. The image is then assigned the class of the model with the highest cosine similarity.

We run eval_model() against the test dataset (10K images) and against the whole training set (60K images). Note that the models were trained on a limited set of images (500 per class) randomly selected from the training set, so the actual training set is only a fraction (10 * 500 / 60000, about 8.3%) of the whole training set. More than 90% of the training images were therefore not seen by the models during the training phase.

In [None]:
def eval_model(dataset_loader):
  miss_classified = 0

  for data in dataset_loader:
    images, labels = data

    loss = []

    for j in range(len(digit_models)):
      out = digit_models[j](images.to(device)).view(images.size(0), 25)

      loss.append(nn.CosineSimilarity(eps=1e-6)(out, torch.tensor([vector_labels[i].tolist() for i in labels]).to(device)).tolist())

    miss_classified += (torch.tensor(loss).t().argsort()[:,-1] != labels).int().sum()

  return miss_classified

print('Test set: miss-classified {} in {} images'.format(eval_model(test_loader), test_loader.dataset.data.size(0)))

print('Train set: miss-classified {} in {} images'.format(eval_model(train_loader), train_loader.dataset.data.size(0)))

Test set: miss-classified 1 in 10000 images
Train set: miss-classified 12 in 60000 images
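
Note that `eval_model()` above builds its target vectors with `vector_labels[i] for i in labels`, that is, from the ground-truth label of each image. A variant that does not consult the labels when predicting, and instead compares the output of model `j` with label vector `j`, might look like the sketch below. The function name `predict_without_labels` and the stand-in models are assumptions for illustration, and the results reported above were not produced with it.

```python
import torch
import torch.nn.functional as F

def predict_without_labels(models, label_vectors, images):
    """Predict a class per image by comparing model j's output with label vector j."""
    sims = []
    with torch.no_grad():
        for j, model in enumerate(models):
            out = model(images).reshape(images.size(0), -1)       # (batch, vector size)
            target = label_vectors[j].expand_as(out)               # same target for every image
            sims.append(F.cosine_similarity(out, target, dim=1))   # (batch,)
    return torch.stack(sims, dim=1).argmax(dim=1)                  # most similar model wins

# Stand-ins for illustration only: 3 "models" that return their input unchanged
toy_models = [lambda x: x] * 3
toy_label_vectors = torch.eye(3)
toy_inputs = torch.tensor([[0.9, 0.1, 0.0],
                           [0.1, 0.2, 0.7]])

predictions = predict_without_labels(toy_models, toy_label_vectors, toy_inputs)
print(predictions.tolist())  # [0, 2]
```

With the notebook's objects, the call would be `predict_without_labels(digit_models, vector_labels.to(device), images.to(device))`, and the predictions could then be compared with `labels` to count errors.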


## Conclusion

Starting from a few labeled examples, we trained models that let us label a large number of images (more than 40K) from the 60K-image training set. Using these additional labeled images, we then trained new, more accurate models that reached close to 100% accuracy on the test set (1 misclassified image out of 10,000).

This approach shows how, starting from very few labeled examples, we can build a model that generalizes well to unseen examples and reaches human-level accuracy.
