In [1]:
import os
import torch
import torchvision

from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt


# Seed PyTorch's default random number generator with a fixed value so that
# random draws made through it start from the same state on every run.
torch.manual_seed(42)

<torch._C.Generator at 0x1ebe0810f10>

In [2]:
# Root folder of the dataset; the train and test splits live in sub-folders of it.
base_dir = 'dataset-vit'
# Hand the path components to os.path.join separately so that it inserts the
# platform's separator itself (joining a single pre-concatenated string is a no-op).
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')
# Sanity check: show the entries of the test split folder (the recorded output
# below lists the two class folders, 'parasitized' and 'uninfected').
os.listdir(test_dir)

['parasitized', 'uninfected']

In [3]:
# Choose the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

In [4]:
# 1. Get pretrained weights for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

# 2. Setup a ViT model instance with pretrained weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# 3. Freeze the base parameters (every parameter that exists at this point);
#    the replacement head created below is built afterwards and stays trainable.
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# 4. Swap the classifier head for a fresh linear layer with one output per class.
#    The input width is read from the head being replaced instead of hard-coding
#    768, so this cell keeps working if a different ViT variant is loaded above.
class_names = ['parasitized', 'uninfected']
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.heads.head.in_features,
                                 out_features=len(class_names)).to(device)

In [5]:
from torchinfo import summary

# Per-layer overview of the model for a dummy batch, including which layers are
# still trainable. For a narrower table pass col_names=["input_size"] instead.
summary(
    pretrained_vit,
    input_size=(32, 3, 224, 224),  # (batch_size, color_channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 2]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              False
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 

In [6]:
# Build the preprocessing pipeline that ships with the pretrained weights and
# print its configuration (the recorded output below shows the resize size,
# crop size, normalisation mean/std and interpolation mode).
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [7]:
# Default number of DataLoader worker processes: one per CPU core.
# os.cpu_count() returns None when the count cannot be determined; fall back to 0,
# which makes the DataLoader load batches in the main process.
NUM_WORKERS = os.cpu_count() or 0

def create_dataloaders(train_dir, test_dir, transform, batch_size, num_workers=NUM_WORKERS):
    """Create training and testing DataLoaders from two image-folder directories.

    Args:
        train_dir: Directory handed to ``datasets.ImageFolder`` for the training set.
        test_dir: Directory handed to ``datasets.ImageFolder`` for the test set.
        transform: Transform applied to the images of both datasets.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each DataLoader.
            Defaults to ``NUM_WORKERS``.

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    # Train and test datasets share the same transform.
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Class names as reported by the training dataset.
    class_names = train_data.classes

    # Honour the caller's num_workers (previously the module-level NUM_WORKERS was
    # always used, so the parameter had no effect).
    # Only the training loader shuffles; the test loader keeps a fixed order.
    train_dataloader = DataLoader(train_data, batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=True)
    test_dataloader = DataLoader(test_data, batch_size, shuffle=False,
                                 num_workers=num_workers, pin_memory=True)

    return train_dataloader, test_dataloader, class_names
    


In [8]:
# Setup dataloaders
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Build both loaders with the preprocessing that belongs to the pretrained weights.
# class_names is re-bound here to the classes reported by the training dataset.
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=pretrained_vit_transforms,
    batch_size=32,
)

In [9]:
from going_modular.going_modular import engine

# Create optimizer and loss function.
# Only parameters that still require gradients are handed to the optimizer; the
# frozen parameters never receive gradients, so tracking them adds nothing.
optimizer = torch.optim.Adam(
    params=[p for p in pretrained_vit.parameters() if p.requires_grad],
    lr=1e-3,
)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the classifier head of the pretrained ViT feature extractor model
pretrained_vit_results = engine.train(model=pretrained_vit,
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=10,
                                      device=device)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.2650 | train_acc: 0.9212 | test_loss: 0.2804 | test_acc: 0.8899
Epoch: 2 | train_loss: 0.0842 | train_acc: 0.9750 | test_loss: 0.2928 | test_acc: 0.8884
Epoch: 3 | train_loss: 0.0653 | train_acc: 0.9743 | test_loss: 0.2657 | test_acc: 0.8839
Epoch: 4 | train_loss: 0.0563 | train_acc: 0.9812 | test_loss: 0.2630 | test_acc: 0.8884
Epoch: 5 | train_loss: 0.0502 | train_acc: 0.9828 | test_loss: 0.2733 | test_acc: 0.8929
Epoch: 6 | train_loss: 0.0455 | train_acc: 0.9821 | test_loss: 0.2620 | test_acc: 0.8988
Epoch: 7 | train_loss: 0.0414 | train_acc: 0.9837 | test_loss: 0.2809 | test_acc: 0.8988
Epoch: 8 | train_loss: 0.0368 | train_acc: 0.9868 | test_loss: 0.2801 | test_acc: 0.8988
Epoch: 9 | train_loss: 0.0364 | train_acc: 0.9891 | test_loss: 0.2628 | test_acc: 0.9048
Epoch: 10 | train_loss: 0.0328 | train_acc: 0.9938 | test_loss: 0.3039 | test_acc: 0.8988
