In this task you will use a modified MNIST dataset in which some images are randomly relabelled.
You are given a pipeline that trains a simple convolutional neural network on this dataset.
The purpose of this task is to show you how to create data visualizations (e.g. by creating image embeddings) and use them to identify problems with the data, such as incorrect labels.

Tasks:
1. Run training and check that the model achieves accuracy below 70% on the test set.
2. Use clusterfun visualizations to identify mislabelled elements of the train dataset.
3. Relabel these elements so that their labels are correct (use the provided `MNISTWithRandomModifications.relabel()` method).
4. Train the model once again on the updated train data. With correct labels it should achieve accuracy above 95% on the test set.


In [None]:
!pip install -q clusterfun umap-learn

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import umap
import clusterfun as clt
from PIL import Image
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
from IPython.display import Javascript, display

def show_port(port, height=700):
    """
    Helper function for displaying iframe with clusterfun interactive visualizations.
    """
    display(Javascript("""
        (async ()=>{
            fm = document.createElement('iframe')
            fm.src = await google.colab.kernel.proxyPort(%s)
            fm.width = '95%%'
            fm.height = '%d'
            fm.frameBorder = 0
            fm.style.background = 'white'
            fm.scrolling = 'no'
            document.body.append(fm)
        })();
    """ % (port, height) ))

In [None]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv_layers = torch.nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(),
            nn.Flatten(),
        )

        self.head = torch.nn.Sequential(
            nn.Linear(7 * 7 * 32, 200),
            nn.BatchNorm1d(200),
            nn.ReLU(),
            nn.Linear(200, 10),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.head(x)
        return x


class MNISTWithRandomModifications(MNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(MNISTWithRandomModifications, self).__init__(root, train, transform, target_transform, download)
        self._extract_images()
        self.dataframe = self._create_dataframe()
        if train:
            self._change_labels_randomly()

    def _extract_images(self):
        if not os.path.exists(self._img_folder):
            os.makedirs(self._img_folder)
            for i, img in enumerate(self.data):
                img_path = os.path.join(self._img_folder, f'img{i}.png')
                Image.fromarray(img.numpy()).save(img_path)

    def _change_labels_randomly(self):
        images_per_label = len(self) // 10
        # For simplicity we randomly change labels in continuous segments based on _id.
        # Each segment contains samples with the same label and a new label is also
        # identical for all elements of the segment.
        num_random_segments = 3
        random_segment_length = 5000

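        # Sort the chosen labels so that segments are processed in increasing _id order,
        # which is the row order of relabel_indices used for the assignment below.
        random_labels = np.sort(np.random.choice(10, num_random_segments, replace=False))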
        new_labels = np.zeros(num_random_segments).astype(np.int64)
        for i, label in enumerate(random_labels):
            possible_new_labels = list(set(range(10)) - {label})
            new_labels[i] = np.random.choice(possible_new_labels, 1).item()

        segment_start_ids = np.zeros(num_random_segments).astype(np.int64)
        for i, random_label in enumerate(random_labels):
            label_ids = self.dataframe[self.dataframe['label'] == random_label]['_id']
            num_elements_with_label = max(label_ids) - min(label_ids)
            random_idx = np.random.choice(num_elements_with_label - random_segment_length, 1).item()
            segment_start_ids[i] = min(label_ids) + random_idx

        segment_end_ids = segment_start_ids + random_segment_length
        relabel_ids = np.concatenate([
            np.arange(start, end)
            for start, end in zip(segment_start_ids, segment_end_ids)
        ])
        relabel_indices = self.dataframe[self.dataframe['_id'].isin(relabel_ids)].index
        new_labels_repeated = new_labels.repeat(random_segment_length)
        self.dataframe.loc[relabel_indices, 'label'] = new_labels_repeated
        self.targets[relabel_indices] = torch.tensor(new_labels_repeated)

    def _create_dataframe(self):
        df = pd.DataFrame({
            'img_path': [
                os.path.abspath(os.path.join(self._img_folder, f'img{i}.png'))
                for i in range(len(self))
            ],
            'label': self.targets,
            'pred': pd.NA,
        })
        # Reorder data so that observations with the same label are placed in contiguous
        # segments with respect to their _id.
        # E.g. samples with label 0 get the lowest ids, followed by samples with label 1 etc.
        # (MNIST classes differ slightly in size, so the segments are not equally long.)
        # This is needed to make the relabelling exercise simple.
        # Note that we do not change the dataframe index. Therefore, to find the element with
        # a given id in self.data we need to do:
        # self.data[self.dataframe[self.dataframe['_id'] == id].index.item()].
        df = df.sort_values('label')
        df['_id'] = range(len(self))
        return df

    @property
    def _img_folder(self):
        return os.path.join(
            self.root,
            self.__class__.__name__,
            'img',
            'train' if self.train else 'test'
        )

    def relabel(self, ids: list[int], new_labels: list[int]):
        """
        Use this function to relabel data after identifying elements with incorrect labels.
        :param ids: Values of '_id' column of elements from the dataset that should be relabelled.
        :param new_labels: New labels for selected dataset elements. Labels must be integers in range 0-9.
        """

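        assert len(ids) == len(new_labels), 'There should be exactly one label for each id'
        # Look up dataset indices in the order of `ids`, so that ids[k] receives new_labels[k]
        # even when `ids` is not sorted.
        id_to_index = pd.Series(self.dataframe.index.to_numpy(), index=self.dataframe['_id'].to_numpy())
        relabel_indices = id_to_index.loc[list(ids)].to_numpy()
        self.targets[torch.as_tensor(relabel_indices)] = torch.tensor(new_labels)
        self.dataframe.loc[relabel_indices, 'label'] = new_labels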

    def update_pred_in_dataframe(self, batch_idx: int, batch_size: int, pred: torch.Tensor):
        """
        Helper function for updating pred in dataframe stored by dataset.
        :param pred: Tensor 1d with predicted labels. Labels must be integers in range 0-9.
        """
        assert len(pred.shape) == 1, 'pred should be 1d tensor'
        df_indices = range(
            batch_idx * batch_size,
            # Last batch may be smaller than batch_size if len(self) % batch_size != 0
            min((batch_idx + 1) * batch_size, len(self))
        )
        # Write to this dataset's own dataframe; move pred to the CPU before converting to numpy.
        self.dataframe.loc[df_indices, 'pred'] = pred.detach().cpu().numpy()

    def generate_embeddings(self, seed=1):
        """
        Generates embeddings of images using UMAP.
        """
        umap_reducer = umap.UMAP(random_state=seed)
        umap_res = umap_reducer.fit_transform(
            ((self.data - self.data.min()) / self.data.max()).view(len(self), -1).numpy()
        )
        assert umap_res.shape == (len(self), 2), f'Incorrect UMAP result shape: {umap_res.shape}'
        self.dataframe.loc[range(len(self)), 'x'] = umap_res[:, 0]
        self.dataframe.loc[range(len(self)), 'y'] = umap_res[:, 1]
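
The difference between the dataframe index and the `_id` column is easy to miss, so here is a minimal sketch on a toy dataframe. It repeats the two steps from `_create_dataframe()` (sort by label, then number the rows) and shows that the index keeps the original dataset positions. The helper `dataset_index_for_id` and the toy labels are made up for this illustration, and `kind='stable'` is added here only to make the toy output deterministic.

```python
import pandas as pd

# Toy stand-in for the dataset dataframe: 6 samples with unsorted labels.
df = pd.DataFrame({'label': [2, 0, 1, 0, 2, 1]})

# Same steps as _create_dataframe(): sort by label, then number the rows.
df = df.sort_values('label', kind='stable')
df['_id'] = range(len(df))

# The index still holds the original dataset positions.
original_positions = df.index.tolist()


def dataset_index_for_id(frame, sample_id):
    """Return the original dataset position of the row with the given _id."""
    return frame[frame['_id'] == sample_id].index.item()


position_of_id_0 = dataset_index_for_id(df, 0)
position_of_id_5 = dataset_index_for_id(df, 5)
print(original_positions, position_of_id_0, position_of_id_5)
```

Here `_id` 0 belongs to the first sample with label 0, which sits at position 1 of the original data, so the index (not `_id`) is what must be used to address `self.data` and `self.targets`.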

In [None]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        pred = output.argmax(dim=1)
        train_loader.dataset.update_pred_in_dataframe(batch_idx, train_loader.batch_size, pred)
        correct += pred.eq(target.view_as(pred)).sum().item()
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            _, _, image_width, image_height = data.size()
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
    print(
        "Train accuracy: {}/{} ({:.0f}%)".format(
            correct, len(train_loader.dataset),
            100.0 * correct / len(train_loader.dataset),
        )
    )


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            test_loader.dataset.update_pred_in_dataframe(batch_idx, test_loader.batch_size, pred)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )

In [None]:
batch_size = 512
test_batch_size = 1000
epochs = 5
lr = 1e-3
use_cuda = True
seed = 1
log_interval = 20

In [None]:
use_cuda = use_cuda and torch.cuda.is_available()

torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {"batch_size": batch_size}
test_kwargs = {"batch_size": test_batch_size}
if use_cuda:
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [None]:
data_dir = "../data"
transform = transforms.ToTensor()

train_dataset = MNISTWithRandomModifications(
    data_dir,
    train=True,
    download=True,
    transform=transform,
)
test_dataset = MNISTWithRandomModifications(
    data_dir,
    train=False,
    download=True,
    transform=transform,
)

train_loader = DataLoader(train_dataset, **train_kwargs)
test_loader = DataLoader(test_dataset, **test_kwargs)

In [None]:
model = ConvNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval)
    test(model, device, test_loader)

## Dataset exploration

The accuracy achieved on the test data is not very high. This is caused by incorrect labels assigned to some of the training images. To identify the mislabelled elements, we will visualize the data.

To visualize the data we first need to generate embeddings. You can do this by calling the `MNISTWithRandomModifications.generate_embeddings()` method. It generates two-dimensional embeddings with UMAP, one of the techniques used for dimensionality reduction. (Other techniques commonly used for this purpose include PCA and t-SNE.)
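
As a point of comparison with UMAP, the sketch below computes a two-dimensional PCA projection using only NumPy. It runs on random toy images and is not part of the provided pipeline; `pca_2d` and `toy_images` are names made up for this example. In practice `sklearn.decomposition.PCA`, which the notebook already imports, could be used for the same purpose.

```python
import numpy as np


def pca_2d(images):
    """Project flattened images onto their first two principal components."""
    flat = images.reshape(len(images), -1).astype(np.float64)
    centered = flat - flat.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


# Random 28x28 images stand in for the MNIST data (assumption for illustration only).
rng = np.random.default_rng(0)
toy_images = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)

embedding = pca_2d(toy_images)
print(embedding.shape)  # (100, 2)
```

The two columns of such an embedding could be stored as `x` and `y` in the dataframe, in the same way `generate_embeddings()` stores the UMAP result. PCA is linear, so on real digits it typically separates classes less clearly than UMAP or t-SNE.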

For the visualizations we will use the [clusterfun](https://clusterfun.app/) library, which provides several types of interactive plots. You can find more details in the documentation on its website. To visualize the embeddings we will use a scatter plot.

Because clusterfun does not support Colab, we use a small helper function, `show_port()`, to display the plots. Copy the port number printed by clusterfun and pass it to this function.

Explore the embeddings visualization and try to find mislabelled examples.

Hint: To keep this scenario simple, labels were changed at random in 3 contiguous segments of `_id` values. Each segment contains 5000 samples that originally shared the same label, and all elements of a segment received the same new label.
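
The hint implies that, when rows are ordered by `_id`, labels form long runs, and a corrupted segment appears as a run that interrupts its neighbours. A minimal sketch of this idea on a toy dataframe follows. `label_runs` is a hypothetical helper, and the toy sizes (3 classes with 10 samples each, one corrupted segment of 4 samples) are chosen only for readability.

```python
import numpy as np
import pandas as pd

# Toy stand-in: 3 classes, 10 samples each, sorted by label like the real dataframe.
df = pd.DataFrame({'_id': range(30), 'label': np.repeat([0, 1, 2], 10)})
# Simulate one corrupted segment: ids 12-15 (true label 1) relabelled as 0.
df.loc[df['_id'].between(12, 15), 'label'] = 0


def label_runs(frame):
    """Summarise consecutive runs of identical labels along _id."""
    ordered = frame.sort_values('_id')
    # A new run starts wherever the label differs from the previous row.
    run_id = (ordered['label'] != ordered['label'].shift()).cumsum().rename('run')
    return (
        ordered.groupby(run_id)
        .agg(label=('label', 'first'), first_id=('_id', 'min'), last_id=('_id', 'max'))
        .reset_index(drop=True)
    )


runs = label_runs(df)
print(runs)
```

In the toy output, the run with label 0 covering ids 12-15 sits between two runs with label 1, which marks it as suspicious. A similar summary of `train_dataset.dataframe` could complement the scatter plot, although the images in the visualization are still needed to confirm which label is actually correct.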



In [None]:
train_dataset.generate_embeddings(seed=seed)

In [None]:
clt.scatter(train_dataset.dataframe, x='x', y='y', color='label', media='img_path')

Serving plot on http://localhost:39153


PosixPath('/root/.cache/clusterfun/0ad93bd9-38a4-454d-b608-a949dbe44fd3')

In [None]:
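# Copy port number from string printed above
# Serving plot on http://localhost:<port_number>
port_number = 39153  # value from the output above; replace it with the port printed in your session
show_port(port_number)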

## Identify and relabel data

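Use the visualization to identify the mislabelled segments and correct them with `train_dataset.relabel()`. Then check how the visualization has changed.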
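
Since `relabel()` expects one label per id, the arguments for a whole segment can be built as two lists of equal length. The sketch below only shows how such lists might be constructed. The start id and the label are placeholder values, not the answer to the exercise, and the call itself is left as a comment because it needs the `train_dataset` object from this notebook.

```python
# Placeholder values; replace them with what you read off the visualization.
segment_start_id = 12000   # first _id of a mislabelled segment (assumed)
segment_length = 5000      # segment length stated in the hint
correct_label = 2          # digit that the images actually show (assumed)

ids = list(range(segment_start_id, segment_start_id + segment_length))
new_labels = [correct_label] * segment_length

# In the notebook this would be followed by:
# train_dataset.relabel(ids, new_labels)
print(len(ids), ids[0], ids[-1])
```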

In [None]:
############## TODO #############
# Identify incorrect labels on visualizations above and change them.
# Use train_dataset.relabel()

In [None]:
############## TODO #############
# Visualize data after changes
# You DO NOT need to recompute embeddings - images are the same, only labels changed

In [None]:
show_port(port_number)

## Retrain the model

After the labels have been corrected, the model should achieve above 95% accuracy on the test set.

In [None]:
############## TODO #############
# Train model using changed data