In [None]:
import torch
import random
import timeit
import numpy 
from torch import optim
from torch import nn
from torchinfo import summary
import torch.nn.functional as F
from torchvision import transforms, datasets
import torchvision
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
@dataclass
class vit_config:
    """Hyper-parameters and runtime settings shared by the cells of this notebook.

    Every field has a default, so both ``vit_config`` (class attributes) and
    ``vit_config()`` (an instance) expose the same values.

    Note:
        ``embd_dim`` and ``num_patches`` are evaluated once, while the class
        body executes, from the default ``patch_size`` / ``num_channels`` /
        ``image_size`` declared above them. They are NOT recomputed when an
        instance overrides those fields, so pass them explicitly if you change
        the image or patch geometry.
    """
    num_channels: int = 3
    batch_size:int = 16
    image_size: int = 224
    patch_size: int = 16
    num_heads:int = 8
    dropout: float = 0.0
    hidden_size: int = 768
    layer_norm_eps: float = 1e-6
    num_encoder_layers: int = 12
    random_seed: int = 42
    epochs: int = 30
    num_classes: int = 10  
    learning_rate: float = 1e-3
    # Weight decay is a real-valued coefficient; it was previously typed as int.
    adam_weight_decay: float = 0.0
    adam_betas: tuple = (0.9, 0.999)
    embd_dim: int = (patch_size ** 2) * num_channels           # 768
    num_patches: int = (image_size // patch_size) ** 2         # 196
    # Resolved once at class-definition time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu" 
    

In [None]:
# Instantiate the dataclass instead of aliasing the class object: attribute
# access (config.batch_size, ...) is unchanged, but the configuration can now
# be customised through constructor arguments.
config = vit_config()

# Seed every random number generator the notebook touches.
random.seed(config.random_seed)
numpy.random.seed(config.random_seed)
torch.manual_seed(config.random_seed)
torch.cuda.manual_seed(config.random_seed)
torch.cuda.manual_seed_all(config.random_seed)

# cuDNN benchmark mode auto-tunes convolution algorithms per run, which can
# pick different kernels between runs; it must be off for the deterministic
# flag to give reproducible results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Get pretrained model weights from torchvision

In [None]:
# Select torchvision's default ViT-B/16 checkpoint, build the model from it,
# then move the model onto the configured device.
model_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit = torchvision.models.vit_b_16(weights=model_weights)
vit = vit.to(config.device)

In [None]:
# Layer-by-layer overview of the pretrained model for a (32, 3, 224, 224) input,
# including which parameters are currently trainable.
summary(
    vit,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

### Let's freeze all the pretrained parameters

In [None]:
# Freeze every pretrained parameter so that only the new head receives gradients.
vit.requires_grad_(False)

# Replace the classification head with a fresh linear layer (trainable by
# default). The input width is read from the model itself rather than being
# hard-coded, so it always matches the backbone's embedding size.
vit.heads = nn.Linear(in_features=vit.hidden_dim, out_features=config.num_classes).to(config.device)

In [None]:
# Same overview as before, re-run after freezing the backbone and swapping the
# head so the "trainable" column reflects the new state.
summary(
    vit,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

### Data preparation

In [None]:
# Root folders of the three dataset splits; each one is later handed to
# datasets.ImageFolder as its root.
train_data_dir, val_data_dir, test_data_dir = (
    "flowers-data/train",
    "flowers-data/valid",
    "flowers-data/test",
)

In [None]:
class _ImageFolderDataset(Dataset):
    """Shared implementation behind the three split-specific datasets.

    Uses ``datasets.ImageFolder`` to discover ``(path, label)`` samples, loads
    each image with PIL, converts it to RGB and applies ``self.transform``.

    Args:
        root_dir: Directory passed to ``datasets.ImageFolder`` as ``root``.
        config: Configuration object; only ``config.image_size`` is read.

    Attributes:
        dataset: The underlying ``ImageFolder``.
        image_size: Side length (pixels) images are resized to.
        classes: Class names reported by the ``ImageFolder``.
        class_to_idx: Class-name -> index mapping reported by the ``ImageFolder``.
        transform: Transform pipeline applied in ``__getitem__``.
    """

    # Per-channel statistics given to transforms.Normalize for every split.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, config:vit_config):
        self.dataset = datasets.ImageFolder(root=root_dir)
        self.image_size = config.image_size
        self.classes = self.dataset.classes
        self.class_to_idx = self.dataset.class_to_idx
        self.transform = self._build_transform()

    def _build_transform(self):
        """Return the deterministic pipeline: resize -> tensor -> normalize."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])

    def __len__(self):
        """Number of samples found by the underlying ``ImageFolder``."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load, convert and transform one sample.

        Returns:
            dict with keys ``"image"`` (transformed image), ``"label"``
            (class index from ``ImageFolder.samples``) and ``"index"``
            (the requested index).
        """
        img_path, label = self.dataset.samples[index]

        # Open inside a context manager so the file handle is released
        # deterministically; convert() yields an independent in-memory image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return {
            "image": image,
            "label": label,
            "index": index
        }


class TrainDataset(_ImageFolderDataset):
    """Training split: same loading logic, with random augmentation added."""

    def _build_transform(self):
        """Return the augmented pipeline used for training samples."""
        # NOTE(review): the fixed Resize squashes the image to a square before
        # RandomResizedCrop samples from it; kept as in the original pipeline.
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.RandomResizedCrop(self.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])


class ValidationDataset(_ImageFolderDataset):
    """Validation split: deterministic resize -> tensor -> normalize pipeline."""


class TestDataset(_ImageFolderDataset):
    """Test split: deterministic resize -> tensor -> normalize pipeline."""
        
        
# Build the datasets from the active `config` object (not the `vit_config`
# class) so that any customised setting is honoured consistently with the
# rest of the notebook.
train_dataset = TrainDataset(train_data_dir, config)
val_dataset = ValidationDataset(val_data_dir, config)
test_dataset = TestDataset(test_data_dir, config)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
)


In [None]:
# Class names exposed by the training dataset; reused by the plotting cell to
# turn predicted / true indices into readable titles. The bare expression on
# the last line displays the list as the cell output.
classes = train_dataset.classes  
classes

### Training

In [None]:
# Only the replacement head is handed to the optimizer; the rest of the model
# is not updated by this loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    vit.heads.parameters(),
    lr=config.learning_rate,
    betas=config.adam_betas,
    weight_decay=config.adam_weight_decay,
)
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs)

start = timeit.default_timer()

for epoch in range(config.epochs):

    # ---- training pass ----
    vit.train()

    train_labels = []
    train_preds = []
    train_running_loss = 0

    for batch in tqdm(train_dataloader, position=0, desc="training"):
        img = batch["image"].float().to(config.device)
        # Class-index targets must be int64; the previous uint8 cast would
        # wrap any class id above 255.
        label = batch["label"].long().to(config.device)

        y_pred = vit(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Store plain Python ints so the accuracy computation below does not
        # compare 0-dim tensors one at a time.
        train_labels.extend(label.tolist())
        train_preds.extend(y_pred_label.tolist())

        loss = criterion(y_pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()

    # Mean of the per-batch losses; uses the loader length rather than a
    # loop variable left over from the iteration.
    train_loss = train_running_loss / len(train_dataloader)

    # ---- validation pass ----
    vit.eval()
    valid_labels = []
    valid_preds = []
    valid_running_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader, position=0, desc="validation"):
            img = batch["image"].float().to(config.device)
            label = batch["label"].long().to(config.device)

            y_pred = vit(img)
            y_pred_label = torch.argmax(y_pred, dim=1)

            valid_labels.extend(label.tolist())
            valid_preds.extend(y_pred_label.tolist())

            loss = criterion(y_pred, label)
            valid_running_loss += loss.item()

    val_loss = valid_running_loss / len(val_dataloader)

    train_accuracy = sum(1 for x, y in zip(train_labels, train_preds) if x == y) / len(train_labels)
    valid_accuracy = sum(1 for x, y in zip(valid_labels, valid_preds) if x == y) / len(valid_labels)

    print("-"*30)

    # The scheduler advances once per epoch, so the rate printed next is the
    # one the optimizer will use for the following epoch.
    scheduler.step()
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    print("-"*30)
    print(f"Train loss epoch {epoch + 1}: {train_loss:.4f}")
    print(f"Valid loss epoch {epoch + 1}: {val_loss:.4f}")
    print(f"Train accuracy epoch {epoch+1} : {train_accuracy:.4f}")
    print(f"Valid accuracy epoch {epoch+1} : {valid_accuracy:.4f}")

stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")


### Testing

In [None]:
# Accumulators read by the plotting cell: per-sample image tensors (on CPU),
# predicted class ids (Python ints) and true labels (0-dim CPU tensors).
test_images, test_preds, test_labels = [], [], []

vit.eval()
with torch.no_grad():
    for batch in tqdm(test_dataloader, position=0, desc="Testing"):
        img = batch["image"].to(config.device)
        label = batch["label"].to(config.device)

        y_pred = vit(img)
        y_pred_label = y_pred.argmax(dim=1)

        # Extending a list with a tensor appends one entry per sample.
        test_images.extend(img.cpu().detach())
        test_preds.extend(y_pred_label.tolist())
        test_labels.extend(label.cpu().detach())

    n_correct = sum(1 for x, y in zip(test_labels, test_preds) if x == y)
    print(f"Test accuracy : {n_correct/len(test_labels):.4f}")

In [None]:
# The stored tensors went through transforms.Normalize; undo it with the same
# per-channel statistics so imshow receives values in [0, 1] and shows true
# colours instead of clipped ones.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# A single 2x3 grid; plt.subplots already creates the figure, so no separate
# plt.figure() call is needed (it would only emit an extra empty figure).
f, axarr = plt.subplots(2, 3, figsize=(8, 5))
counter = 0

for i in range(2):
    for j in range(3):
        ax = axarr[i][j]

        # Fewer than six test samples: blank the unused panels instead of
        # failing with an IndexError.
        if counter >= len(test_images):
            ax.axis("off")
            continue

        # Denormalize, clamp to the displayable range, and go CHW -> HWC.
        image = (test_images[counter].squeeze() * norm_std + norm_mean).clamp(0, 1)
        ax.imshow(image.permute(1, 2, 0).numpy())

        # int() accepts both plain ints and 0-dim tensors as list indices.
        pred_label = classes[int(test_preds[counter])]
        actual_label = classes[int(test_labels[counter])]

        # Title shows both the predicted and the actual label.
        ax.set_title(f"Pred: {pred_label}\nActual: {actual_label}", fontsize=12)

        # Remove axis ticks for a cleaner display.
        ax.set_xticks([])
        ax.set_yticks([])

        counter += 1

plt.tight_layout(pad=3.0)  # Add padding between subplots
plt.suptitle("Model Predictions vs Actual Labels", fontsize=16, y=0.98)
plt.show()

In [None]:
# Persist the whole fine-tuned module object (not just its state_dict) to disk.
# NOTE(review): a file written this way is presumably only loadable where the
# same model class definitions can be imported - confirm before sharing it.
torch.save(obj=vit, f='transfer_learning_vit_full.pth')