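# Distilling a ResNet-50 teacher into a ViT student on Galaxy Zoo 2

The distillable ViT student and the distillation wrapper come from [vit-pytorch](https://github.com/lucidrains/vit-pytorch).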

In [1]:
# Report the GPU model and memory usage (the output below shows a 4 GB card).
!nvidia-smi

Sun Sep 12 12:48:09 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 471.11       Driver Version: 471.11       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| 21%   51C    P5    N/A /  75W |    278MiB /  4096MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
! pip -q install vit_pytorch linformer

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, f1_score
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, utils
import torch.optim as optim
from torch.optim import lr_scheduler

import time
import os
import zipfile
from copy import deepcopy

from vit_pytorch.distill import DistillableViT, DistillWrapper

%matplotlib inline

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [4]:
# Seed PyTorch's random number generator so the run is reproducible.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x1ad706632b0>

In [5]:
# Dataset / loader configuration
NUM_OF_CLASSES = 8  # number of galaxy classes (see gxy_labels in the evaluation cell)
BATCH_SIZE = 64     # mini-batch size

## Galaxy Zoo 2 Datasets

In [6]:
## Custom Galaxy Zoo 2 Dataset
class GalaxyZooDataset(Dataset):
    """Galaxy Zoo 2 image dataset.

    Sample i is the image ``<images_dir>/<galaxyID>.jpg`` for row i of the label
    CSV, paired with the integer class stored in that row's ``label1`` column.
    """

    def __init__(self, csv_file, images_dir, transform=None):
        """
        Args:
            csv_file (string): path to the label csv; it must contain the
                columns 'galaxyID' and 'label1' (all other columns are dropped)
            images_dir (string): path to the dir containing all images
            transform (callable, optional): transform applied to the PIL image
        """
        # Keep only the two columns __getitem__ reads (by position 0 and 1).
        self.labels_df = pd.read_csv(csv_file)[['galaxyID', 'label1']].copy()

        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows in the label CSV)."""
        return len(self.labels_df)

    def __getitem__(self, idx):
        """
        Get the idx-th sample.

        Returns:
            (image, label): the RGB image after the optional transform, and the
            integer label. With the ToTensor-based transforms used in this
            notebook the image is a channel-first tensor.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        galaxy_id = str(self.labels_df.iloc[idx, 0])
        image_path = os.path.join(self.images_dir, galaxy_id + '.jpg')

        # Image.open is lazy and keeps the file handle open; decode inside a
        # context manager so the handle is released immediately. convert('RGB')
        # returns a fully loaded copy and guarantees 3 channels for the
        # 3-channel Normalize, even if a file is stored as e.g. grayscale.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = int(self.labels_df.iloc[idx, 1])
        return image, label

## Data Augmentation Transforms

In [7]:
def create_data_transforms(is_for_inception=False, augment_valid=False):
    """
    Create PyTorch data transforms for the Galaxy Zoo datasets.

    Args:
        is_for_inception (bool): True for Inception networks (299x299 crops);
            otherwise crops are 224x224.
        augment_valid (bool): if True, the validation transform applies the same
            random augmentations as the training transform (the previous
            behaviour). By default validation is deterministic, like the test
            transform, so the validation loss is reproducible and comparable
            between epochs.
    Returns:
        train_transform: centre crop + random rotation/flips/resized crop + normalisation
        valid_transform: validation transform (deterministic unless augment_valid)
        test_transform: centre crop + normalisation only
    """
    input_size = 299 if is_for_inception else 224

    # ImageNet channel mean / std, applied after ToTensor.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build(augment):
        # Every pipeline starts with the same centre crop and ends with ToTensor + Normalize.
        steps = [transforms.CenterCrop(input_size)]
        if augment:
            steps += [transforms.RandomRotation(90),
                      transforms.RandomHorizontalFlip(),
                      transforms.RandomVerticalFlip(),
                      transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(0.99, 1.01))]
        steps += [transforms.ToTensor(),
                  transforms.Normalize(mean, std)]
        return transforms.Compose(steps)

    train_transform = build(augment=True)
    valid_transform = build(augment=augment_valid)
    test_transform = build(augment=False)

    return train_transform, valid_transform, test_transform

In [8]:
# Data loaders
# (BATCH_SIZE is defined once, in the configuration cell near the top.)

# create transforms
train_transform, valid_transform, test_transform = create_data_transforms(is_for_inception=False)

# create datasets
data_train = GalaxyZooDataset('gz2_train.csv', 'images_train', train_transform)
data_valid = GalaxyZooDataset('gz2_valid.csv', 'images_valid', valid_transform)
data_test = GalaxyZooDataset('gz2_test.csv', 'images_test', test_transform)

# dataloaders: only the training data needs shuffling; evaluation loaders keep
# a fixed order so validation/test passes are reproducible.
train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(data_valid, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)

# check the sizes
print("**Dataloaders**")
print("Number of training data: {} ({} batches)".format(len(data_train), len(train_loader)))
print("Number of validation data: {} ({} batches)".format(len(data_valid), len(valid_loader)))
print("Number of test data: {} ({} batches)".format(len(data_test), len(test_loader)))
print("===============================")

**Dataloaders**
Number of training data: 99808 (1560 batches)
Number of validation data: 24952 (390 batches)
Number of test data: 31191 (488 batches)


## Training the distiller

In [9]:
def train_distiller(model, num_epochs, optimizer, scheduler, print_every=1, early_stop_epochs=10,
                    train_loader=None, valid_loader=None, device=None):
    """
    Train a distillation wrapper and return it with its best weights restored.

    The wrapper is called as ``model(images, labels)`` and must return a scalar
    (batch-mean) loss, as the DistillWrapper used in this notebook does. If the
    wrapper has a ``teacher`` sub-module, the teacher is kept in eval mode while
    the student trains, so the teacher's BatchNorm running statistics are not
    modified by the distillation batches.

    Args:
        model: distiller wrapper
        num_epochs: maximum number of epochs to train
        optimizer: the optimizer
        scheduler: the learning rate decay scheduler (stepped once per epoch)
        print_every: print the losses every X epochs
        early_stop_epochs: stop if the validation loss has not improved for X
            consecutive epochs; None or 0 disables early stopping
        train_loader, valid_loader, device: default to the notebook-level
            ``train_loader``, ``valid_loader`` and ``device`` variables
    Returns:
        the model, with the weights of the epoch that had the lowest
        validation loss loaded back in
    """
    # Fall back to the notebook globals so existing calls keep working.
    notebook_globals = globals()
    if train_loader is None:
        train_loader = notebook_globals['train_loader']
    if valid_loader is None:
        valid_loader = notebook_globals['valid_loader']
    if device is None:
        device = notebook_globals['device']

    best_valid_loss = float('inf')
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        # ---- Train ----
        model.train()
        # model.train() recurses into the teacher; put it back in eval mode so its
        # BatchNorm layers use (and do not update) their running statistics.
        teacher = getattr(model, 'teacher', None)
        if isinstance(teacher, torch.nn.Module):
            teacher.eval()

        epoch_train_cum_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.long().to(device)

            optimizer.zero_grad()
            loss = model(images, labels)
            loss.backward()
            optimizer.step()

            # loss is a batch mean: weight by batch size to get a per-sample average later
            epoch_train_cum_loss += loss.item() * images.size(0)

        # ---- Validation ----
        model.eval()
        epoch_valid_cum_loss = 0.0
        with torch.no_grad():
            for images, labels in valid_loader:
                images = images.to(device)
                labels = labels.long().to(device)
                loss = model(images, labels)
                epoch_valid_cum_loss += loss.item() * images.size(0)

        train_loss = epoch_train_cum_loss / max(len(train_loader.dataset), 1)
        valid_loss = epoch_valid_cum_loss / max(len(valid_loader.dataset), 1)

        # epoch duration as whole minutes and seconds
        mm, ss = divmod(int(round(time.time() - epoch_start_time)), 60)

        ## Print metrics
        if (epoch+1) % print_every == 0:
            print("Epoch {}/{}\tTrain loss: {:.4f}\tValid loss: {:.4f}\tTime: {:.0f}m {:.0f}s".format(
                epoch+1, num_epochs, train_loss, valid_loss, mm, ss))

        scheduler.step()

        # ---- Best model tracking / early stopping ----
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # snapshot on the CPU so keeping the best weights costs no extra GPU memory
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if early_stop_epochs and epochs_without_improvement >= early_stop_epochs:
                print("Early stopping: no improvement in valid loss for {} epochs".format(
                    epochs_without_improvement))
                break

    # return the best model
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

## ResNet-50 teacher

In [10]:
# Teacher: ResNet-50 with its classifier resized to the galaxy classes.
# pretrained=False: every weight is replaced by the Galaxy Zoo 2 checkpoint below
# (load_state_dict is strict, so all keys must be present), so downloading the
# ImageNet weights first would be wasted work.
teacher = models.resnet50(pretrained=False)
# replace the last dense layer (its input width is 2048 for ResNet-50)
teacher.fc = nn.Linear(teacher.fc.in_features, NUM_OF_CLASSES)
# load gz2-pretrained weights; map_location='cpu' lets a GPU-saved checkpoint load on any
# machine (the model is moved to `device` later)
teacher.load_state_dict(torch.load('gz2_resnet50_b64_lr000005_ss10_gamma01_E200_e33.pth',
                                   map_location='cpu'))
# print architecture
print(teacher)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Teacher's performance

In [None]:
# Evaluate the Galaxy-Zoo-fine-tuned teacher on the test set.
teacher = teacher.to(device)
teacher.eval()

y_true = []
y_pred = []

# iterate test data (no gradients needed for evaluation)
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        pred_logits = teacher(images)
        pred_classes = pred_logits.argmax(dim=1)

        # reshape(-1) instead of squeeze(): a final batch of size 1 must still give a
        # list, whereas squeeze() would give a 0-d tensor whose tolist() is an int.
        y_true += labels.long().reshape(-1).tolist()
        y_pred += pred_classes.cpu().reshape(-1).tolist()


gxy_labels = ['Round Elliptical',
              'In-between Elliptical',
              'Cigar-shaped Elliptical',
              'Edge-on Spiral',
              'Barred Spiral',
              'Unbarred Spiral',
              'Irregular',
              'Merger']
class_ids = list(range(len(gxy_labels)))

# confusion matrix; labels= keeps it square over all classes even if a class is
# missing from both y_true and y_pred (otherwise the DataFrame below would fail)
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
cm_df = pd.DataFrame(cm, index=gxy_labels, columns=gxy_labels)

print("Confusion matrix:")
print(cm)
print("\n")

# plot (rows = true class, columns = predicted class)
fig, ax = plt.subplots(figsize=(15, 10))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="YlGnBu", ax=ax)
ax.set(title="Teacher confusion matrix (test set)", xlabel="Predicted class", ylabel="True class")
plt.show()

# class-wise accuracy: correct predictions / true samples of the class (per-class recall)
for c in class_ids:
    n_true = cm[c, :].sum()
    class_acc = cm[c, c] / n_true if n_true else float('nan')
    print("Class {}: accuracy = {:.4f} ({})".format(c, class_acc, gxy_labels[c]))
print("\n")

# accuracy
acc = accuracy_score(y_true, y_pred)
print("Total Accuracy = {:.4f}\n".format(acc))

# recall
recall = recall_score(y_true, y_pred, average='macro')
print("Recall = {:.4f}\n".format(recall))

# f1 score
F1 = f1_score(y_true, y_pred, average='macro')
print("F1 score = {:.4f}\n".format(F1))

## Student

In [11]:
# --- ViT student architecture ---
PATCH_SIZE = 14      # 224 / 14 -> 16 x 16 patches per image
DEPTH = 12           # number of transformer layers
HIDDEN_DIM = 128     # token embedding width (`dim`)
NUM_HEADS = 8        # attention heads per layer
MLP_DIM = 128        # hidden width of each feed-forward block
DROPOUT = 0.1
EMB_DROPOUT = 0.1
K_DIM = 64           # NOTE(review): not passed to DistillableViT in this notebook — TODO confirm it is still needed

# --- Optimisation ---
LR = 3e-4
STEP_SIZE = 5        # StepLR: multiply the learning rate by GAMMA every STEP_SIZE epochs
GAMMA = 0.9
MAX_EPOCH = 200

In [12]:
# ViT student; image_size must match the 224 px crops produced by create_data_transforms().
student = DistillableViT(
    image_size=224,
    patch_size=PATCH_SIZE,
    num_classes=NUM_OF_CLASSES,
    dim=HIDDEN_DIM,
    depth=DEPTH,
    heads=NUM_HEADS,
    mlp_dim=MLP_DIM,
    dropout=DROPOUT,
    emb_dropout=EMB_DROPOUT,
)

In [13]:
# Wrap student and teacher; calling distiller(images, labels) returns the combined training loss.
distiller = DistillWrapper(
    student=student,
    teacher=teacher,
    temperature=3,   # temperature of distillation
    alpha=0.5,       # trade-off between the main (label) loss and the distillation loss
    hard=False,      # False = soft distillation, True = hard distillation
)

In [14]:
# Adam over every parameter of the wrapper, with the LR multiplied by GAMMA every STEP_SIZE epochs.
# NOTE(review): distiller.parameters() also yields the teacher's weights; presumably they receive
# no gradients during distillation — confirm if the teacher is meant to stay frozen.
optimizer = optim.Adam(params=distiller.parameters(), lr=LR)
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=STEP_SIZE, gamma=GAMMA)

In [15]:
# Move the distiller to the selected device and train it for a single epoch.
# NOTE(review): this run hit "CUDA out of memory" on the 4 GB GPU (see output below); a smaller
# BATCH_SIZE is one option to try — TODO confirm. MAX_EPOCH is defined but not used here.
distiller = distiller.to(device)
distiller = train_distiller(model=distiller,
                            num_epochs=1,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            print_every=1,
                            early_stop_epochs=10)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


RuntimeError: CUDA out of memory. Tried to allocate 34.00 MiB (GPU 0; 4.00 GiB total capacity; 2.84 GiB already allocated; 0 bytes free; 2.92 GiB reserved in total by PyTorch)