In [1]:
import os
import torch
import torchvision

from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt


torch.manual_seed(42)

<torch._C.Generator at 0x238e3acaf50>

In [2]:
base_dir = 'dataset-vit'
train_dir = os.path.join(base_dir + '/train')
test_dir = os.path.join(base_dir + '/test')
os.listdir(test_dir)

['parasitized', 'uninfected']

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [4]:
# 1. Get pretrained weights for Swin Tiny
pretrained_swin_weights = torchvision.models.Swin_T_Weights.DEFAULT 

# 2. Setup a Swin model instance with pretrained weights
pretrained_swin = torchvision.models.swin_t().to(device)

# 3. Freeze the base parameters
for parameter in pretrained_swin.parameters():
    parameter.requires_grad = False
    
class_names = ['parasitized', 'uninfected']
pretrained_swin.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)

In [5]:
from torchinfo import summary

summary(model=pretrained_swin, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
SwinTransformer (SwinTransformer)                       [32, 3, 224, 224]    [32, 1000]           1,538                Partial
├─Sequential (features)                                 [32, 3, 224, 224]    [32, 7, 7, 768]      --                   False
│    └─Sequential (0)                                   [32, 3, 224, 224]    [32, 56, 56, 96]     --                   False
│    │    └─Conv2d (0)                                  [32, 3, 224, 224]    [32, 96, 56, 56]     (4,704)              False
│    │    └─Permute (1)                                 [32, 96, 56, 56]     [32, 56, 56, 96]     --                   --
│    │    └─LayerNorm (2)                               [32, 56, 56, 96]     [32, 56, 56, 96]     (192)                False
│    └─Sequential (1)                                   [32, 56, 56, 96]     [32, 56, 56, 96]     --                   Fal

In [6]:
pretrained_swin_transforms = pretrained_swin_weights.transforms()
print(pretrained_swin_weights)

ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BICUBIC
)


In [7]:
NUM_WORKERS = os.cpu_count()

def create_dataloaders(train_dir, test_dir, transform, batch_size, num_workers=NUM_WORKERS):
    
    #train and test data
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)
    
    #class names
    class_names = train_data.classes
    
    train_dataloader = DataLoader(train_data, batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    test_dataloader = DataLoader(test_data, batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    
    return train_dataloader, test_dataloader, class_names
    


In [10]:
# Setup dataloaders
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(train_dir=train_dir,
                                                                                          test_dir=test_dir,
                                                                                          transform=pretrained_swin_transforms,
                                                                                          batch_size=32) 

In [11]:
from going_modular.going_modular import engine

# Create optimizer and loss function
optimizer = torch.optim.Adam(params=pretrained_swin.parameters(), 
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the classifier head of the pretrained  feature extractor model
pretrained_swin_results = engine.train(model=pretrained_swin,
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=5,
                                      device=device)

  0%|          | 0/5 [00:00<?, ?it/s]

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn