<a href="https://colab.research.google.com/github/yuzenlin/Colab/blob/main/CIFAR10_Classification_DeiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CIFAR10 Classification with Data-efficient Image Transformers (DeiT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

from torchvision import datasets, transforms, models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import spacy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import time
import math
from PIL import Image
import glob
from IPython.display import display

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Seed both RNGs from one constant so the run is repeatable.
# NumPy is seeded last on purpose: its seed call returns None, so the cell
# produces no output (torch.manual_seed returns a Generator object).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Training configuration shared by the cells below.
BATCH_SIZE: int = 32   # samples per mini-batch for both train and test loaders
LR: float = 5e-5       # learning rate handed to the Adam optimizer
NUM_EPOCHES: int = 25  # number of full passes over the training set

## Preprocessing

In [None]:
# Normalisation statistics. A single value is given for each; it is applied
# to every channel by broadcasting inside Normalize.
mean = (0.5,)
std = (0.5,)

# Convert PIL images to tensors, then normalise them with the stats above.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Training split: downloaded on first use, reshuffled every epoch.
trainset = datasets.CIFAR10(
    root='../data/CIFAR10/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True)

# Test split: same preprocessing, iterated in a fixed order.
testset = datasets.CIFAR10(
    root='../data/CIFAR10/',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(dataset=testset, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 170M/170M [00:05<00:00, 29.7MB/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-1199584384>", line 1, in <cell line: 0>
    trainset = datasets.CIFAR10('../data/CIFAR10/', download=True, train=True, transform=transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/cifar.py", line 66, in __init__
    self.download()
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/cifar.py", line 139, in download
    download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
  File "/usr/local/lib/python3.11/dist-packages/torchvision/datasets/utils.py", line 394, in download_and_extract_archive
    extract_archive(archive, extract_root, remove_finished)
  File "/usr/local/lib/python3.11/dist-

TypeError: object of type 'NoneType' has no len()

## Pre-trained Teacher Model

In [None]:
# Teacher checkpoint produced by:
# https://github.com/UdbhavPrasad072300/Transformer-Implementations/blob/main/pre-train/VGG16_CIFAR10.ipynb
#
# NOTE(security): torch.load unpickles arbitrary Python objects. Only load
# checkpoints that come from a source you trust.
#
# map_location=device lets a checkpoint saved on a GPU be restored on a
# CPU-only runtime. weights_only=False is required because the file holds a
# whole pickled module (attributes are set on it below), which the
# weights-only loader that is the default in recent PyTorch releases rejects.
teacher_model = torch.load(
    "../trained_models/vgg16_cifar10.pth",
    map_location=device,
    weights_only=False,
)

In [None]:
# Turn off the teacher's `preprocess_flag` and echo it to confirm the change.
# NOTE(review): presumably this disables preprocessing inside the teacher's
# forward pass because the loaders already normalise inputs - confirm against
# the teacher model's implementation.
teacher_model.preprocess_flag = False
print(teacher_model.preprocess_flag)

## DeiT

In [None]:
from transformer_package.models import DeiT

In [None]:
# --- Input geometry -------------------------------------------------------
image_size = 32     # CIFAR10 images are 32x32
channel_size = 3    # RGB
patch_size = 4      # side length of one patch passed to DeiT

# --- Transformer capacity -------------------------------------------------
embed_size = 512
hidden_size = 512
num_heads = 8
num_layers = 4
dropout = 0.2

# --- Task -----------------------------------------------------------------
classes = 10        # CIFAR10 has ten categories

# Build the student; the pre-trained teacher is handed to it for distillation.
model = DeiT(
    image_size=image_size,
    channel_size=channel_size,
    patch_size=patch_size,
    embed_size=embed_size,
    num_heads=num_heads,
    classes=classes,
    num_layers=num_layers,
    hidden_size=hidden_size,
    teacher_model=teacher_model,
    dropout=dropout,
)
model = model.to(device)
model

In [None]:
# Smoke test: push a single training batch through the model and report the
# tensor shapes on the way in and out. Only the first batch is used.
for img, label in trainloader:
    img, label = img.to(device), label.to(device)

    print(f"Input Image Dimensions: {img.size()}")
    print(f"Label Dimensions: {label.size()}")
    print("-" * 100)

    # The model returns the student's output and the teacher's output.
    out, teacher_out = model(img)

    print(f"Student Output Dimensions: {out.size()}")
    print(f"Teacher Output Dimensions: {teacher_out.size()}")
    break

## Training with Hard Label Distillation

In [None]:
from transformer_package.loss_functions.loss import Hard_Distillation_Loss

In [None]:
# Loss used for hard-label distillation (called with teacher predictions,
# student predictions and ground-truth labels in the training loop).
criterion = Hard_Distillation_Loss()

# Adam over every parameter exposed by the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# NOTE(review): the training loop in this notebook never calls
# scheduler.step(), so the learning rate stays at LR for the whole run.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1)

In [None]:
# Per-epoch training history, consumed by the plotting cells.
#   "train accuracy": accuracy in percent for each epoch
#   "train loss":     mean batch loss for each epoch
loss_hist = {"train accuracy": [], "train loss": []}

for epoch in range(1, NUM_EPOCHES + 1):
    model.train()

    epoch_train_loss = 0.0  # running sum of per-batch losses
    num_batches = 0

    y_true_train = []
    y_pred_train = []

    for img, labels in trainloader:
        img = img.to(device)
        labels = labels.to(device)

        preds, teacher_preds = model(img)

        loss = criterion(teacher_preds, preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep plain Python ints so no tensors (or graphs) are retained.
        y_pred_train.extend(preds.detach().argmax(dim=-1).tolist())
        y_true_train.extend(labels.detach().tolist())

        epoch_train_loss += loss.item()
        num_batches += 1

    # Report the *mean* batch loss: the message below says "mean", and a mean
    # is comparable across batch sizes and dataset sizes, unlike the raw sum.
    mean_train_loss = epoch_train_loss / max(num_batches, 1)
    loss_hist["train loss"].append(mean_train_loss)

    total_correct = sum(p == t for p, t in zip(y_pred_train, y_true_train))
    total = len(y_pred_train)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = total_correct * 100 / total if total else 0.0

    loss_hist["train accuracy"].append(accuracy)

    print("-------------------------------------------------")
    print("Epoch: {} Train mean loss: {:.8f}".format(epoch, mean_train_loss))
    print("       Train Accuracy%: ", accuracy, "==", total_correct, "/", total)
    print("-------------------------------------------------")

## Test

In [None]:
# Training accuracy per epoch. The y-axis is accuracy in percent (the
# original label said "Loss", which mislabelled this curve).
plt.plot(loss_hist["train accuracy"])
plt.xlabel("Epoch")
plt.ylabel("Train Accuracy (%)")
plt.show()

In [None]:
# Training loss recorded once per epoch, drawn with the explicit Axes API.
fig, ax = plt.subplots()
ax.plot(loss_hist["train loss"])
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

In [None]:
# Evaluate the student on the held-out test split.
model.eval()  # switch layers such as dropout to inference behaviour

y_true_test = []
y_pred_test = []

with torch.no_grad():  # no gradients are needed for evaluation
    for img, labels in testloader:
        img = img.to(device)
        # Move this batch's labels. The original line read the stale `label`
        # variable left over from the shape-check cell instead of `labels`.
        labels = labels.to(device)

        # Only the student's predictions are scored; the teacher's are unused.
        preds, _teacher_preds = model(img)

        y_pred_test.extend(preds.argmax(dim=-1).tolist())
        y_true_test.extend(labels.tolist())

total_correct = sum(p == t for p, t in zip(y_pred_test, y_true_test))
total = len(y_pred_test)
# Guard against an empty loader instead of raising ZeroDivisionError.
accuracy = total_correct * 100 / total if total else 0.0

print("Test Accuracy%: ", accuracy, "==", total_correct, "/", total)