## Installing the relevant libraries.

In [None]:
!pip install tqdm
!pip install transformers datasets
!pip install torchinfo


## Importing the packages

In [None]:
import numpy as np
import torch
import torch.nn as nn
from datasets import Array3D, ClassLabel, Features, load_dataset
from matplotlib import pyplot
from numpy import inf
from sklearn.utils.class_weight import compute_class_weight
from torchinfo import summary
from tqdm import tqdm
from transformers import AdamW, ViTFeatureExtractor, ViTModel

## Downloading the data and preparing the train, validation and test datasets

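Here we load a small portion of the CIFAR-10 dataset: the first 5,000 training images and the first 2,000 test images. 10% of the training subset is then held out for validation.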

In [None]:
# Load a small CIFAR-10 subset: the first 5,000 training and 2,000 test images.
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])
# Hold out 10% of the training subset for validation. The fixed seed keeps the
# train/validation split identical across kernel restarts.
splits = train_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']

## Feature Extractor

We are using `ViTFeatureExtractor`. This feature extractor will resize/rescale the images to the same resolution (224x224) and normalize them across the RGB channels with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).

In [None]:
# Load the image preprocessing configuration published with the pre-trained checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)

In [None]:
# Look at one raw training example as a sanity check.
chk = train_ds[67]
a = np.array(chk['img'])
print(a.shape)
cats = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Human-readable class name of the sample.
print(cats[chk['label']])
# No colormap is passed: matplotlib ignores `cmap` for RGB(A) image data, so the
# former 'gray' colormap had no effect and only suggested a grayscale plot.
# The trailing semicolon hides the AxesImage repr in the cell output.
pyplot.imshow(a);

## Preprocess Images
The `preprocess_images` function processes each image in the dataset. The processed images are then fed to the model.

In [None]:
def preprocess_images(examples):
    """Add model-ready ``pixel_values`` to a batch of examples.

    Args:
        examples: Batch dictionary whose ``'img'`` entry holds the images of
            the batch.

    Returns:
        The same ``examples`` object, with the output of the notebook-level
        ``feature_extractor`` stored under ``'pixel_values'``.
    """
    # Convert every image to a uint8 array and move its last axis to the front,
    # i.e. (H, W, C) -> (C, H, W) for channel-last input images.
    channel_first = [
        np.moveaxis(np.array(picture, dtype=np.uint8), source=-1, destination=0)
        for picture in examples['img']
    ]
    # Run the feature extractor once on the whole batch.
    encoded = feature_extractor(images=channel_first)
    examples['pixel_values'] = encoded['pixel_values']
    return examples

The Hugging Face Datasets `.map(function, batched=True)` functionality is used to apply the `preprocess_images` function to every item in the dataset.

In [None]:
# The output schema is declared by hand because both `img` and `pixel_values`
# are stored as 3D arrays.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
features = Features({
    'label': ClassLabel(names=class_names),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

# Apply the same batched preprocessing, with the same schema, to every split.
map_options = dict(batched=True, features=features)
preprocessed_train_ds = train_ds.map(preprocess_images, **map_options)
preprocessed_val_ds = val_ds.map(preprocess_images, **map_options)
preprocessed_test_ds = test_ds.map(preprocess_images, **map_options)

In [None]:
# Return the model inputs and the targets as PyTorch tensors for every split.
for split_ds in (preprocessed_train_ds, preprocessed_val_ds, preprocessed_test_ds):
    split_ds.set_format('torch', columns=['pixel_values', 'label'])

In [None]:
# Display the preprocessed training split as a quick sanity check.
preprocessed_train_ds

## Preparing the train, validation and test data loaders

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Batch sizes for training and for evaluation.
train_batch_size = 10
eval_batch_size = 10

# Only the training loader shuffles its data.
train_dataloader = torch.utils.data.DataLoader(
    preprocessed_train_ds,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=2,
)
val_dataloader = torch.utils.data.DataLoader(
    preprocessed_val_ds,
    batch_size=eval_batch_size,
    num_workers=2,
)
test_dataloader = torch.utils.data.DataLoader(
    preprocessed_test_ds,
    batch_size=eval_batch_size,
    num_workers=2,
)

# Fetch a single training batch for inspection in the following cells.
batch = next(iter(train_dataloader))

In [None]:
# Summarise the sample batch by tensor shape instead of dumping the raw tensors.
{name: tuple(tensor.shape) for name, tensor in batch.items()}

In [None]:
# The sample batch must have one (3, 224, 224) image and one label per example.
expected_pixel_shape = (train_batch_size, 3, 224, 224)
expected_label_shape = (train_batch_size,)
assert batch['pixel_values'].shape == expected_pixel_shape
assert batch['label'].shape == expected_label_shape

## Define the model

Here we are using Vision Transformer (ViT) model pre-trained on ImageNet-21k (14 million images, 21,843 classes) at resolution 224x224.

Images are presented to the model as a sequence of fixed-size patches (resolution 16x16), which are linearly embedded. One also adds a [CLS] token to the beginning of a sequence to use it for classification tasks. One also adds absolute position embeddings before feeding the sequence to the layers of the Transformer encoder.

https://huggingface.co/google/vit-base-patch16-224-in21k

In [None]:
# Load the pre-trained ViT encoder ...
vit_model = ViTModel.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)
# ... and freeze all of its weights so that they receive no gradients.
for encoder_weight in vit_model.parameters():
    encoder_weight.requires_grad = False

In [None]:
class ViTForImageClassification(nn.Module):
    """Linear classification head on top of a ViT encoder.

    The head takes the encoder's hidden state at sequence position 0, applies
    dropout, and maps it to ``num_labels`` logits.

    Args:
        num_labels: Number of output classes.
        vit: Encoder to wrap. It must expose ``config.hidden_size`` and, when
            called with ``pixel_values=...``, return an object that has a
            ``last_hidden_state`` attribute. When omitted, the notebook-level
            ``vit_model`` is used, so ``ViTForImageClassification()`` keeps
            working as before.
    """

    def __init__(self, num_labels=10, vit=None):
        super().__init__()
        # The global is only looked up when no encoder is passed in, which
        # makes the class usable (and testable) without the notebook state.
        self.vit = vit_model if vit is None else vit
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values):
        """Return ``num_labels`` logits for every example in ``pixel_values``."""
        outputs = self.vit(pixel_values=pixel_values)
        # Position 0 of the sequence is used as the summary of the whole image.
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)
        return logits

In [None]:
# Build the classifier and move it to the selected device in one step.
model = ViTForImageClassification().to(device)
# summary(model)

## Computing the class weights to handle the data imbalance

In [None]:
# Labels of the training split as a plain Python list.
train_labels = preprocessed_train_ds['label'].tolist()

# Compute "balanced" class weights. `classes` and `y` are passed by keyword
# because current scikit-learn releases reject them as positional arguments.
class_wts = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(train_labels),
    y=train_labels,
)
# print(class_wts)

# Weights are ordered by sorted class id; this lines up with the label ids as
# long as every class occurs at least once in the training split.
weights = torch.tensor(class_wts, dtype=torch.float)
weights = weights.to(device)

# Class-weighted loss function.
cross_entropy = nn.CrossEntropyLoss(weight=weights)

## Setting the Optimizer and Epochs

In [None]:
# Only parameters that still require gradients are optimised; the frozen
# encoder weights are left out.
trainable_params = [p for p in model.parameters() if p.requires_grad]

# PyTorch's AdamW replaces the deprecated `transformers.AdamW`. `eps` and
# `weight_decay` are set explicitly to match the defaults of the transformers
# implementation, so the update rule stays the same.
optimizer = torch.optim.AdamW(trainable_params, lr=1e-3, eps=1e-6, weight_decay=0.0)

# number of training epochs
epochs = 50

## Model Train function
This function runs one training epoch over the train dataloader.

In [None]:
# function to train the model
def train():
    """Run one training epoch over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``cross_entropy``,
    ``train_dataloader`` and ``device``.

    Returns:
        tuple: ``(avg_loss, total_preds)``, where ``avg_loss`` is the mean of
        the per-batch losses and ``total_preds`` is a NumPy array holding the
        model outputs of all batches concatenated along axis 0.
    """
    model.train()
    total_loss = 0

    # model outputs of every batch, collected as NumPy arrays
    total_preds = []

    for step, batch in enumerate(train_dataloader):

        # progress update after every 50 batches
        if step % 50 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))

        # Look the tensors up by key: unpacking `batch.items()` positionally
        # silently swaps labels and images if the column order ever changes.
        lbl = batch['label'].to(device)
        pix = batch['pixel_values'].to(device)

        # forward pass and loss for the current batch
        preds = model(pix)
        loss = cross_entropy(preds, lbl)
        total_loss = total_loss + loss.item()

        # backward pass to calculate the gradients
        loss.backward()

        # clip the gradient norm to 1.0 to guard against exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # update the parameters, then clear the gradients for the next batch
        optimizer.step()
        optimizer.zero_grad()

        # keep the outputs, detached from the graph and moved to the CPU
        total_preds.append(preds.detach().cpu().numpy())

    # average loss over the batches of this epoch
    avg_loss = total_loss / len(train_dataloader)

    total_preds = np.concatenate(total_preds, axis=0)

    return avg_loss, total_preds



## Model Eval function
This function evaluates the model on the validation dataloader.

In [None]:
def eval():
    """Return the average loss of ``model`` over ``val_dataloader``.

    Uses the notebook-level ``model``, ``cross_entropy``, ``val_dataloader``
    and ``device``. The result is the mean of the per-batch losses.

    NOTE(review): the name shadows Python's built-in ``eval``; it is kept
    because the training cell calls ``eval()``.
    """
    total_loss = 0
    model.eval()  # prep model for evaluation
    # No gradients are needed for validation, so skip building the graph.
    with torch.no_grad():
        for batch in val_dataloader:
            # Look the tensors up by key instead of relying on the order of
            # `batch.items()`.
            lbl = batch['label'].to(device)
            pix = batch['pixel_values'].to(device)

            # forward pass: compute predicted outputs by passing inputs to the model
            preds = model(pix)
            # accumulate the batch loss
            loss = cross_entropy(preds, lbl)
            total_loss += loss.item()

    return total_loss / len(val_dataloader)

## Training the model 

Train the model, tracking the training loss on the train dataloader and the validation loss on the validation dataloader after every epoch.

In [None]:
# Ask cuDNN for deterministic algorithms and disable its auto-tuner. This is
# done once, before the first training step, so that it applies to every epoch
# (previously it was only set after the first epoch had already run).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Stop once the validation loss has failed to improve this many epochs in a row.
patience = 5

min_loss = inf  # best validation loss seen so far
es = 0          # consecutive epochs without improvement
for epoch in range(epochs):

    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

    # One pass over the training data, then score the validation data.
    train_loss, _train_outputs = train()
    val_loss = eval()

    # Early stopping
    if val_loss < min_loss:
        min_loss = val_loss
        es = 0
    else:
        es += 1
        if es >= patience:
            print("Early stopping with train_loss: ", train_loss, "and val_loss for this epoch: ", val_loss, "...")
            break

    print(f'\n Training Loss: {train_loss:.3f}')
    print(f'\n Validation Loss: {val_loss:.3f}')

## Save the model weights

In [None]:
# Uncomment to save the trained weights (the path below is environment-specific).
# torch.save(model.state_dict(), '/kaggle/working/model')

## Load the model weights

In [None]:
# Uncomment to rebuild the model and restore previously saved weights.
# NOTE(review): `strict=False` lets `load_state_dict` accept missing or
# unexpected keys instead of raising -- confirm that this is intended.
# model = ViTForImageClassification()
# model.load_state_dict(torch.load('/kaggle/working/model'), strict=False)

# # push the model to GPU
# model = model.to(device)
# summary(model)

## Testing the model on Test DataLoader

In [None]:
# NOTE(review): this redefinition replaces the validation `eval` defined
# earlier in the notebook. Re-run that cell before re-running the training
# loop, otherwise `val_loss = eval()` receives the tuple returned here.
def eval():
    """Predict the class of every example in ``test_dataloader``.

    Uses the notebook-level ``model``, ``test_dataloader`` and ``device``.

    Returns:
        tuple: ``(y_pred, y_true)``, two lists with the predicted and the true
        class index of each test example, in dataloader order.
    """
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for batch in tqdm(test_dataloader, total=len(test_dataloader)):
            # Look the tensors up by key instead of relying on the order of
            # `batch.items()`.
            lbl = batch['label'].to(device)
            pix = batch['pixel_values'].to(device)

            # The predicted class is the index of the largest logit.
            outputs = torch.argmax(model(pix), dim=1)
            y_pred.extend(outputs.cpu().numpy())
            y_true.extend(lbl.cpu().numpy())

    return y_pred, y_true

y_pred, y_true = eval()

In [None]:
# Accuracy is the fraction of test examples whose prediction matches the label.
correct = np.asarray(y_pred) == np.asarray(y_true)
accuracy = correct.mean()
print("Accuracy of the model", accuracy)