**MOUNT DRIVE (IF FOR EXAMPLE YOU WANT TO READ/WRITE WEIGHTS FROM MyDrive):**

In [1]:
# Mount the user's Google Drive inside the Colab VM at /content/drive so that
# later cells can read and write files under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Change the current working directory from `/content` to this notebook's project directory, `/content/drive/MyDrive/img_cls_to_vit`. This allows the hard-coded relative paths from the local-machine setup to work here in Colab:**

In [2]:
# Switch the working directory to the project folder on the mounted Drive so
# that relative paths used in later cells (e.g. './data', 'saved_models/...')
# resolve inside it.
import os
os.chdir('/content/drive/MyDrive/img_cls_to_vit')
# Last expression of the cell: echoes the new working directory as a sanity check.
os.getcwd()

'/content/drive/MyDrive/img_cls_to_vit'

**INSTALL ALLOWED LIBRARIES (~80 secs):**

In [3]:
from time import time
start = time()
!pip install torch
!pip install torchvision
!pip install pillow
!pip install tqdm
!pip install transformers # will need to remove later
print(f'Pip installed torch, torchvision, pillow and tqdm in {round(time() - start, 4)} secs')

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m58.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m49.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m77.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [10]:
# !cat /etc/*release

**IMPORT LIBRARIES AND SET THE PROCESSOR DEVICE:**

In [1]:
# train script
from time import time
from tqdm import tqdm
import torch
from torch import nn
import torchvision
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
from PIL import Image
from mix_up import MixUp
# import multiprocessing

def set_device():
    """
    Choose the torch device to run on and announce the choice.

    Preference order: CUDA GPU first, then Apple-Silicon MPS, and finally the
    CPU when neither accelerator is reported as available.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'

    chosen = torch.device(backend)
    print(f'Using {chosen} device')
    return chosen
# Module-level device handle used by the cells that follow.
device = set_device()

Using cpu device


**IMPORT `CIFAR-10` IMAGE TRAINING DATASET, TRANSFORM, LOAD TO DATALOADER & ITERATOR, LOOK AT EXAMPLE (~ 3 sec)**<br>
(The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. It is divided into 50,000 training images and 10,000 test images. The 10 classes are: plane, car, bird, cat, deer, dog, frog, horse, ship & truck.)

In [2]:
# Load torchvision's ViT-B/16 with its default pretrained weights, together
# with the preprocessing pipeline published alongside those weights.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
pretrained_transforms = pretrained_vit_weights.transforms()

# Freeze every pretrained parameter so that only the new head (below) is trained.
pretrained_vit.requires_grad_(False)

# Replace the classification head with a fresh 10-way linear layer (one output
# per CIFAR-10 class). A newly constructed layer is trainable by default.
pretrained_vit.heads = nn.Sequential(nn.Linear(in_features=768, out_features=10))

# Move the model to the selected device. The training loop moves each batch to
# `device`, so leaving the model on the CPU fails as soon as a GPU is selected.
pretrained_vit = pretrained_vit.to(device)

In [23]:
# `torchsummary` is provided by the `torch-summary` package, which the install
# cell above does not install; uncomment the next line if the import fails.
# !pip install torch-summary
from torchsummary import summary
# Print a per-layer table (output shape and parameter count) of the model for a
# single 3x224x224 input.
summary(pretrained_vit, (3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
├─Conv2d: 1-1                                 [-1, 768, 14, 14]         (590,592)
├─Encoder: 1-2                                [-1, 197, 768]            --
|    └─Dropout: 2-1                           [-1, 197, 768]            --
|    └─Sequential: 2-2                        [-1, 197, 768]            --
|    |    └─EncoderBlock: 3-1                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-2                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-3                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-4                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-5                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-6                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-7                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock:

Layer (type:depth-idx)                        Output Shape              Param #
├─Conv2d: 1-1                                 [-1, 768, 14, 14]         (590,592)
├─Encoder: 1-2                                [-1, 197, 768]            --
|    └─Dropout: 2-1                           [-1, 197, 768]            --
|    └─Sequential: 2-2                        [-1, 197, 768]            --
|    |    └─EncoderBlock: 3-1                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-2                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-3                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-4                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-5                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-6                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock: 3-7                 [-1, 197, 768]            (7,087,872)
|    |    └─EncoderBlock:

In [3]:
# Legacy preprocessing from the earlier 'facebook/deit-...-patch16-224'
# experiments. It is not used by the datasets below; kept for reference only.
transform = tv_transforms.Compose([
    tv_transforms.Resize((224, 224)),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR-10 train (50,000 images) and test splits, preprocessed with the
# transforms that ship with the pretrained ViT weights.
trainset = tv_datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_transforms)
testset = tv_datasets.CIFAR10(root='./data', train=False, download=True, transform=pretrained_transforms)

assert len(trainset) == 50000
batch_size = 20
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# The 10 CIFAR-10 class names, indexed by integer label:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

dataiter = iter(trainloader)  # Create iterator of image batches
images, labels = next(dataiter)  # take one example batch out to sanity check.

# Undo the normalisation applied by `pretrained_transforms` before saving.
# The previous `x / 2 + 0.5` inverse only matches Normalize(0.5, 0.5) (the
# unused `transform` above), so with the pretrained preprocessing it produced
# out-of-range values that wrapped around when cast to uint8.
norm_mean = torch.tensor(pretrained_transforms.mean).view(1, -1, 1, 1)
norm_std = torch.tensor(pretrained_transforms.std).view(1, -1, 1, 1)
images_unit = (images * norm_std + norm_mean).clamp(0, 1)

# Lay the batch out side by side: (channels, height, width * batch).
image_strip = torch.cat(list(images_unit), dim=2)
# Save example images to jpg (PIL expects height x width x channels, uint8):
im = Image.fromarray(image_strip.mul(255).permute(1, 2, 0).numpy().astype('uint8'))
im.save('train_pt_images_vit.jpg')
print('train_pt_images_vit.jpg saved.')
# Iterate the labels actually returned rather than assuming `batch_size` of them.
print('Ground truth labels:' + ' '.join('%5s' % classes[label] for label in labels))

Files already downloaded and verified
train_pt_images_vit.jpg saved.
Ground truth labels:truck horse  ship plane  bird   car  bird  bird  ship  frog  ship   cat   cat   dog horse  deer horse  frog horse   dog


**(ONLY NEED TO DO THIS ONCE)<br>
LOAD PRE-TRAINED DeiT TINY ViT MODEL FROM HUGGINGFACE AND SAVE IT TO A LOCAL `.pt` FILE (~1 sec):**

In [None]:
# ViT: instantiate pretrained model
# !pip install transformers
# from transformers import DeiTConfig, DeiTModel
# start = time()
# from google.colab import userdata
# userdata.get('shahin_HF_datasets')

# The model is initially installed from HF via `transformers` library,
# but subsequently saved to `saved_models/pretrained` dir as `transformers`
# is not allowed.

# # Init DeiT `deit-tiny-patch16-224` style configuration, (with random weights)
# config = DeiTConfig.from_pretrained('facebook/deit-tiny-patch16-224')
# config = DeiTConfig.from_pretrained('facebook/deit-small-patch16-224')
# net = DeiTModel(config)
# # torch.save(net, 'saved_models/pretrained/deit_tiny_vit.pt')
# torch.save(net, 'saved_models/pretrained/deit_small_vit.pt')
# # Save model architecture to file:
# with open('network_pt_deit_small.py', 'w') as f:
#     f.write(str(net))
# print(f'Pretrained DeiT model uploaded in {round(time() - start, 4)} secs')

In [None]:
# Re-enter the project directory on the mounted Drive. This repeats the chdir
# done near the top of the notebook, so it only matters if the working
# directory has changed since (e.g. after a kernel restart).
import os
os.chdir('/content/drive/MyDrive/img_cls_to_vit')

**LOAD PRE-TRAINED DeiT TINY ViT MODEL FROM LOCAL .PT FILE:**

In [8]:
# net_tiny = torch.load(f='saved_models/pretrained/deit_tiny_vit.pt')
# net_tiny = net_tiny.to(device)
# print(f'Local pretrained DeiT model uploaded in {round(time() - start, 2)} secs')
# print(f'DeiT tiny network architecture {net_tiny}')

Local pretrained DeiT model uploaded in 76.12 secs
DeiT tiny network architecture DeiTModel(
  (embeddings): DeiTEmbeddings(
    (patch_embeddings): DeiTPatchEmbeddings(
      (projection): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): DeiTEncoder(
    (layer): ModuleList(
      (0-11): 12 x DeiTLayer(
        (attention): DeiTAttention(
          (attention): DeiTSelfAttention(
            (query): Linear(in_features=192, out_features=192, bias=True)
            (key): Linear(in_features=192, out_features=192, bias=True)
            (value): Linear(in_features=192, out_features=192, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): DeiTSelfOutput(
            (dense): Linear(in_features=192, out_features=192, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): DeiTIntermediate(
          (dense): Linear(in_f

In [13]:
# # net = torch.load(f='saved_models/pretrained/deit_tiny_vit.pt')
# net_small = torch.load(f='saved_models/pretrained/deit_small_vit.pt')
# net_small = net_small.to(device)
# print(f'Local pretrained DeiT model uploaded in {round(time() - start, 2)} secs')
# print(f'DeiT network architecture {net_small}')

Local pretrained DeiT model uploaded in 67.3 secs
DeiT network architecture DeiTModel(
  (embeddings): DeiTEmbeddings(
    (patch_embeddings): DeiTPatchEmbeddings(
      (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): DeiTEncoder(
    (layer): ModuleList(
      (0-11): 12 x DeiTLayer(
        (attention): DeiTAttention(
          (attention): DeiTSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): DeiTSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): DeiTIntermediate(
          (dense): Linear(in_feature

**INITIALISE CROSS ENTROPY LOSS FUNCTION AND ADAM OPTIMISER:**

In [4]:
# Cross-entropy loss over the class logits, placed on the same device as the data.
criterion = torch.nn.CrossEntropyLoss().to(device)

# Hand Adam only the parameters that are still trainable. Every parameter that
# has `requires_grad=False` (the frozen backbone) is left out, so the optimiser
# does not have to track tensors it can never update.
trainable_params = [p for p in pretrained_vit.parameters() if p.requires_grad]
optimiser = torch.optim.Adam(trainable_params, lr=0.001)

**LOAD PARTIALLY TRAINED MODEL IF PRESENT:**

In [None]:
# # Loading model and optimiser state:
# checkpoint = torch.load('checkpoint.pth')
# net.load_state_dict(checkpoint['model_state_dict'])
# optimiser.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

In [None]:
def _calc_accuracy(predicted_class, ground_truth):
    """
    Fraction of samples in a batch whose highest-scoring class matches the label.

    Args:
        predicted_class: tensor of per-class scores with classes along dim 1
            (the training loop passes the model's raw outputs).
        ground_truth: tensor of integer class labels, one per sample.

    Returns:
        float: accuracy in [0, 1]; 0.0 for an empty batch.
    """
    if ground_truth.numel() == 0:
        return 0.0  # avoid the NaN that the mean of an empty tensor would give
    # Softmax is monotonic within each row, so argmax of the raw scores picks
    # the same class without the extra computation.
    y_pred_class = torch.argmax(predicted_class, dim=1)
    return (y_pred_class == ground_truth).float().mean().item()


**TRAIN MODEL FOR 20 EPOCHS:**

In [27]:
# Fine-tune `pretrained_vit` on the CIFAR-10 training batches, then save its
# weights under `saved_models/pretrained_finetuned`.
start = time()
epochs = 20
log_every = 2000  # print the running average loss every `log_every` mini-batches

for epoch in tqdm(range(epochs)):

    pretrained_vit.train()
    # `loss` stays None until the first batch so an empty loader cannot crash
    # the end-of-epoch report.
    loss, cumulative_loss = None, 0.0
    print(f'\nEpoch number {epoch}')

    for i, data in enumerate(trainloader):
        # data is a pair of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
        optimiser.zero_grad()  # zero parameter gradients
        predicted_class = pretrained_vit(inputs)  # forward
        loss = criterion(predicted_class, labels)
        loss.backward()  # backward
        cumulative_loss += loss.item()
        optimiser.step()  # optimise
        _calc_accuracy(predicted_class, ground_truth=labels)

        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, cumulative_loss / log_every))
            # Reset the accumulator that is actually averaged above (the
            # original reset an unrelated `running_loss` variable).
            cumulative_loss = 0.0

    # # Save model state at each epoch:
    # torch.save({
    #             'epoch': epoch,
    #             'model_state_dict': net.state_dict(),
    #             'optimizer_state_dict': optimiser.state_dict(),
    #             'loss': loss,
    #             }, f'checkpoint{epoch}.pth')

    if loss is not None:
        print(f'loss={loss.item()} at epoch={epoch}')

print(f'Completed training for {epochs} epochs.')
print(f'Training model for {epochs} epochs took {round(((time() - start) / 60), 4)} mins')

tuned_model_dirs = 'saved_models/pretrained_finetuned'
os.makedirs(tuned_model_dirs, exist_ok=True)
# NOTE(review): the file name still says "deit_small" although the weights
# saved here are those of `pretrained_vit` - confirm the intended name.
deit_small_tuned_path = os.path.join(tuned_model_dirs, 'deit_small_vit_tuned.pt')
torch.save(pretrained_vit.state_dict(), deit_small_tuned_path)
print('Trained model saved.')

  0%|          | 0/20 [00:00<?, ?it/s]


Epoch number 0
labels tensor([3, 3, 7, 8, 9, 8, 4, 8, 7, 9, 2, 8, 2, 6, 2, 0, 9, 2, 1, 5])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([3, 7, 3, 1, 7, 4, 8, 4, 7, 8, 6, 5, 6, 1, 1, 9, 7, 4, 1, 0])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([0, 5, 7, 5, 7, 7, 7, 2, 9, 5, 1, 4, 5, 0, 6, 8, 3, 5, 6, 4])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([4, 8, 1, 6, 6, 2, 3, 8, 1, 4, 1, 4, 7, 2, 9, 0, 0, 1, 8, 7])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([9, 0, 4, 1, 0, 9, 3, 3, 6, 9, 5, 2, 5, 9, 0, 7, 6, 0, 0, 3])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([3, 9, 5, 6, 7, 5, 0, 2, 8, 9, 4, 8, 1, 2, 5, 8, 9, 2, 8, 1])
predicted_class.shape torch.Size([20, 10])
type(labels[0]) <class 'torch.Tensor'>
labels tensor([8, 5, 5, 1, 1, 2, 5, 

  0%|          | 0/20 [04:45<?, ?it/s]


KeyboardInterrupt: 

**Save fine-tuned pretrained DeiT-tiny model:**

In [None]:
# Save the fine-tuned weights to `saved_models/pretrained_finetuned`.
tuned_model_dirs = 'saved_models/pretrained_finetuned'
os.makedirs(tuned_model_dirs, exist_ok=True)
deit_tiny_tuned_path = os.path.join(tuned_model_dirs, 'deit_tiny_vit_tuned.pt')
# NOTE(review): this cell previously saved `net`, which is only defined in
# commented-out cells, and had a stray closing parenthesis (a SyntaxError).
# `pretrained_vit` is the model the training cell fine-tunes; confirm that the
# "deit_tiny" file name is still the intended target.
torch.save(pretrained_vit.state_dict(), deit_tiny_tuned_path)
print('Trained model saved.')