In [27]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# First make sure to install timm
!pip install timm

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
from tqdm import tqdm
import numpy as np
import timm
import torchvision
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session



## Device Selection

We use the GPU if one is available and fall back to the CPU otherwise.

In [5]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'Using device: {device}')

Using device: cpu


## Download Dataset

We can use CIFAR10 or CIFAR100 as our dataset. Both are widely used, so `torchvision.datasets` provides ready-made classes to download and load them.

To select the dataset, modify the `use_CIFAR10` boolean variable accordingly.

We also apply some basic __preprocessing__:

1. Convert each image to a tensor and resize it to 224×224, the input resolution of the ViT
2. Normalize each channel by subtracting its _mean_ and dividing by its _std_

In [13]:
from torch.utils.data import DataLoader, random_split

use_CIFAR10 = True

# Precomputed per-channel mean and std of the training images, used to normalize the dataset.
# To compute them: accumulate, per channel, the sum and the sum of squares of all
# pixel values, then apply the mean and variance formulas.
# NOTE(review): the CIFAR-10 std below is the widely copied value obtained by averaging
# per-image stds; the std over all pixels is roughly (0.2470, 0.2435, 0.2616).
# NOTE(review): the pretrained ViT was trained with its own normalization (see
# timm.data.resolve_data_config); normalizing with CIFAR statistics may shift the features.
if use_CIFAR10:
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
else:
    mean = (0.5071, 0.4865, 0.4409)
    std = (0.2673, 0.2564, 0.2762)

# dataset directory remains the same for both cases
dataset_directory = "/kaggle/working/"

# the transformation also remains the same
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((224, 224)),  # vit_tiny_patch16_224 expects 224x224 inputs
    torchvision.transforms.Normalize(mean, std)
])

# download the dataset (train split + held-out test split)
dataset_class = torchvision.datasets.CIFAR10 if use_CIFAR10 else torchvision.datasets.CIFAR100
cifar_dataset = dataset_class(root=dataset_directory, train=True, download=True, transform=transform)
test_dataset = dataset_class(root=dataset_directory, train=False, download=True, transform=transform)

print(f'Dataset downloaded. Total images: {len(cifar_dataset)}')

# Split the dataset into train / validation sets.
# A seeded generator makes the split identical across kernel restarts.
SEED = 42
train_size = int(0.9 * len(cifar_dataset))
val_size = len(cifar_dataset) - train_size

train_dataset, val_dataset = random_split(
    cifar_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# set the batch size to 64
batch_size = 64

# Only the training loader shuffles. The evaluation loaders keep their dataset's order,
# so the i-th image they yield is index i of val_dataset / test_dataset.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f'Train images: {train_size}')
print(f'Validation images: {val_size}')
print(f'Test images: {len(test_dataset)}')

Files already downloaded and verified
Files already downloaded and verified
Dataset downloaded. Total images: 50000
Train images: 45000
Validation images: 5000
Test images: 10000


## Useful Methods

Below we define a few helper functions to keep the code simple.

### Get CLS Token

By taking a look at the [documentation](https://huggingface.co/docs/timm/en/feature_extraction) for the timm library and specifically the __Feature Extraction__ section we can see that in order to get the __last hidden state__ of the model we have to use the `forward_features` method.

Specifically, this method returns the token embeddings of the last hidden state, __before pooling is applied__. The returned tensor has shape

```
(batch_size, num_patches + 1, hidden_size)
```

By design, the __CLS token__ is the __first token__ of this sequence: it is prepended to the patch embeddings, hence the `+ 1` above.

For example, to get the CLS token of the first image in the batch:

```py
model_output[0, 0, :]
```

In [43]:
def get_cls_token(model: "timm.models.vision_transformer.VisionTransformer", images: torch.Tensor) -> np.ndarray:
    """Return the CLS-token embedding of every image in a batch.

    Args:
        model: ViT model; only its ``forward_features`` method is called.
        images: batch of preprocessed images, presumably (batch, 3, 224, 224)
            for vit_tiny_patch16_224 (TODO confirm for other models).

    Returns:
        numpy array of shape (batch_size, hidden_size): token 0 of the last
        hidden state for each image.
    """
    # Feed the batch to the same device as the model weights (a no-op when both are on the CPU).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        images = images.to(first_param.device)

    # Inference only: disabling autograd saves memory and lets .numpy() succeed
    # even when some parameters still have requires_grad=True.
    with torch.no_grad():
        output = model.forward_features(images)

    # Index 0 along the token axis is the CLS token; move to host memory before converting.
    return output[:, 0, :].cpu().numpy()

### Extract CLS Tokens for all images in a DataLoader

This function uses the `get_cls_token` function above to extract the CLS tokens of all images in a given dataloader.

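For each image we also keep:

1. its position, i.e. its index in the order the loader yields images (this equals the dataset index only when the loader does not shuffle)
2. its label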

The method returns __a dictionary__ with:
1. __key__: The original position of the image
2. __value__: A dictionary with `cls_token` and `label` keys

In [44]:
def get_dataset_cls_tokens(
    model: "timm.models.vision_transformer.VisionTransformer",
    loader: torch.utils.data.dataloader.DataLoader,
) -> dict:
    """Extract the CLS token and label of every image yielded by ``loader``.

    Args:
        model: ViT passed to ``get_cls_token`` for each batch.
        loader: iterable of ``(images, labels)`` batches that supports ``len()``
            (used for the progress bar).

    Returns:
        dict mapping position -> {"cls_token": np.ndarray, "label": int}.
        The position is the 0-based index of the image in the order the loader
        yields it, counted across all batches. It matches the dataset index only
        if the loader does not shuffle.
    """
    cls_token_dictionary = {}
    position = 0  # running index across batches; the in-batch index alone would collide

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Calculating CLS Tokens", total=len(loader)):
            # numpy array of shape (batch_size, hidden_size)
            cls_tokens = get_cls_token(model=model, images=images)

            for row in range(cls_tokens.shape[0]):
                cls_token_dictionary[position] = {
                    "cls_token": cls_tokens[row, :],
                    # plain int instead of a 0-d tensor view that keeps the batch's label tensor alive
                    "label": int(labels[row]),
                }
                position += 1

    return cls_token_dictionary

## Load the model

To load the model we will use the `timm` library.

We could also use the `transformers` library, more specifically its `ViTForImageClassification` class.

In [45]:
model = timm.create_model(
    "vit_tiny_patch16_224",                  # ViT-Tiny, 16x16 patches, 224x224 input
    pretrained=True,                         # load timm's default pre-trained ImageNet weights
    num_classes=len(cifar_dataset.classes),  # head sized to the selected CIFAR variant (10 or 100)
)
# NOTE(review): the classifier head is not used when extracting features via forward_features.
# NOTE(review): the model is left on the CPU even when `device` is a GPU.

# Inference mode: disables training-only behaviour such as dropout.
model.eval()

# We don't train here, so freeze all the layers.
for param in model.parameters():
    param.requires_grad = False

## Generate CLS Tokens for the Train Set

In [50]:
# CLS token + label for every image of the training split
train_cls_tokens = get_dataset_cls_tokens(model=model, loader=train_loader)

Calculating CLS Tokens: 100%|██████████| 704/704 [14:40<00:00,  1.25s/it]


## Generate CLS Tokens for the Test Set

In [51]:
# CLS token + label for every image of the test split
test_cls_tokens = get_dataset_cls_tokens(model=model, loader=test_loader)

Calculating CLS Tokens: 100%|██████████| 157/157 [03:15<00:00,  1.25s/it]
