In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from vit_pytorch import ViT
from tqdm import tqdm  # Import tqdm
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Root folders of the three dataset splits, relative to the notebook's working directory.
train_root_dir, val_root_dir, test_root_dir = (
    "./../processed_RealDS/train",
    "./../processed_RealDS/validation",
    "./../processed_RealDS/test",
)

In [4]:
# Hyperparameters shared by the data loaders, the optimizer and the training loop
num_epochs = 35          # number of full passes over the training set
batch_size = 32          # images per mini-batch, used by every DataLoader
learning_rate = 0.0001   # step size handed to the Adam optimizer

In [5]:
####################################
# Preprocessing pipelines
####################################

# Per-channel mean / std handed to Normalize in every pipeline.
_norm_mean = [0.4762, 0.3054, 0.2368]
_norm_std = [0.3345, 0.2407, 0.2164]


def _build_pipeline(*augmentations):
    """Resize to 224x224, apply the given augmentations, then tensorize and normalize."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])


trans = {
    # Only the training split is augmented (random rotation + horizontal flip).
    'train': _build_pipeline(
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
    ),
    # Validation and test images are only resized and normalized.
    'valid': _build_pipeline(),
    'test': _build_pipeline(),
}

In [6]:
# Datasets: one ImageFolder per split, each paired with its own transform pipeline
training_dataset = ImageFolder(root=train_root_dir, transform=trans['train'])
validation_dataset = ImageFolder(root=val_root_dir, transform=trans['valid'])
test_dataset = ImageFolder(root=test_root_dir, transform=trans['test'])

# Loaders: only the training loader shuffles; validation and test keep dataset order
train_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(validation_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Report how many images each split contains
for split_name, split_dataset in (
    ('Training', training_dataset),
    ('Validation', validation_dataset),
    ('Test', test_dataset),
):
    print('Number of {} set images:{}'.format(split_name, len(split_dataset)))

Number of Training set images:1600
Number of Validation set images:19
Number of Test set images:30


In [7]:
# Set the device to use: CUDA first, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Device: ", device)

# Initialize the Vision Transformer model and move it to the device *before*
# building the optimizer, so the optimizer is constructed over parameters that
# already live where training will run.
model = ViT(
    image_size=224,  # matches the Resize((224, 224)) applied by every transform pipeline
    patch_size=32,  # 224 / 32 = 7 patches per side
    num_classes=len(training_dataset.classes),  # one output per class found by ImageFolder
    dim=768,  # Dimension of the model
    depth=12,  # Number of transformer layers
    heads=12,  # Number of attention heads
    mlp_dim=3072,  # Dimension of the MLP layers
    dropout=0.1  # dropout inside the transformer blocks; the embedding dropout keeps its default
).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Last expression of the cell: display the architecture
model

Device:  cuda


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=768, bias=True)
    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.1, inplace=False)
          (to_qkv): Linear(in_features=768, out_features=2304, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Dropout(p=0.1, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
          

## Training

In [8]:
# Training loop: each epoch runs one optimisation pass over the training set,
# then a gradient-free pass over the validation set to report loss and accuracy.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0   # sum of per-sample training losses seen so far this epoch
    train_samples = 0    # number of training samples seen so far this epoch

    # tqdm wraps the loader, so the bar advances by one for every batch drawn
    with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion yields a per-batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_len = labels.size(0)
            running_loss += loss.item() * batch_len
            train_samples += batch_len

            # Show the running per-sample mean loss on the progress bar
            pbar.set_postfix(loss=running_loss / train_samples)

    print(f'Training - Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/train_samples}')

    # Validation
    model.eval()
    val_loss = 0.0     # sum of per-sample validation losses
    val_samples = 0    # number of validation samples seen
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_len = labels.size(0)
            val_loss += loss.item() * batch_len
            val_samples += batch_len

            # Predicted class = index of the largest logit in each row
            predicted = outputs.argmax(dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    print(f'Validation - Epoch {epoch+1}/{num_epochs}, Loss: {val_loss/val_samples}')

    # Compute and print validation accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)
    print(f'Validation - Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy * 100:.2f}%')

# Save the trained model (weights as they stand after the final epoch)
torch.save(model.state_dict(), 'vit_model.pth')
print("Training completed and model saved to vit_model.pth")

Epoch 1/35: 100%|██████████| 50/50 [00:17<00:00,  2.86batch/s, loss=0.673]


Training - Epoch 1/35, Loss: 0.6725830394029617
Validation - Epoch 1/35, Loss: 0.43615612387657166
Validation - Epoch 1/35, Accuracy: 84.21%


Epoch 2/35: 100%|██████████| 50/50 [00:14<00:00,  3.47batch/s, loss=0.346]


Training - Epoch 2/35, Loss: 0.3458257269859314
Validation - Epoch 2/35, Loss: 0.5094606280326843
Validation - Epoch 2/35, Accuracy: 78.95%


Epoch 3/35: 100%|██████████| 50/50 [00:14<00:00,  3.51batch/s, loss=0.303]


Training - Epoch 3/35, Loss: 0.30270934030413627
Validation - Epoch 3/35, Loss: 0.4410644769668579
Validation - Epoch 3/35, Accuracy: 84.21%


Epoch 4/35: 100%|██████████| 50/50 [00:14<00:00,  3.38batch/s, loss=0.195]


Training - Epoch 4/35, Loss: 0.19477710239589213
Validation - Epoch 4/35, Loss: 0.259253591299057
Validation - Epoch 4/35, Accuracy: 89.47%


Epoch 5/35: 100%|██████████| 50/50 [00:14<00:00,  3.46batch/s, loss=0.106] 


Training - Epoch 5/35, Loss: 0.1063788290321827
Validation - Epoch 5/35, Loss: 0.3562778830528259
Validation - Epoch 5/35, Accuracy: 89.47%


Epoch 6/35: 100%|██████████| 50/50 [00:14<00:00,  3.47batch/s, loss=0.098] 


Training - Epoch 6/35, Loss: 0.0980429395288229
Validation - Epoch 6/35, Loss: 0.31843802332878113
Validation - Epoch 6/35, Accuracy: 89.47%


Epoch 7/35: 100%|██████████| 50/50 [00:14<00:00,  3.44batch/s, loss=0.0809]


Training - Epoch 7/35, Loss: 0.08085588354617357
Validation - Epoch 7/35, Loss: 0.35794970393180847
Validation - Epoch 7/35, Accuracy: 89.47%


Epoch 8/35: 100%|██████████| 50/50 [00:14<00:00,  3.45batch/s, loss=0.0646]


Training - Epoch 8/35, Loss: 0.06461690817959606
Validation - Epoch 8/35, Loss: 0.3275861144065857
Validation - Epoch 8/35, Accuracy: 89.47%


Epoch 9/35: 100%|██████████| 50/50 [00:14<00:00,  3.50batch/s, loss=0.056] 


Training - Epoch 9/35, Loss: 0.055961893247440456
Validation - Epoch 9/35, Loss: 0.20932935178279877
Validation - Epoch 9/35, Accuracy: 89.47%


Epoch 10/35: 100%|██████████| 50/50 [00:14<00:00,  3.50batch/s, loss=0.0632]


Training - Epoch 10/35, Loss: 0.06315691477619112
Validation - Epoch 10/35, Loss: 0.381751149892807
Validation - Epoch 10/35, Accuracy: 89.47%


Epoch 11/35: 100%|██████████| 50/50 [00:14<00:00,  3.38batch/s, loss=0.0514]


Training - Epoch 11/35, Loss: 0.05142363966908306
Validation - Epoch 11/35, Loss: 0.4255070388317108
Validation - Epoch 11/35, Accuracy: 89.47%


Epoch 12/35: 100%|██████████| 50/50 [00:14<00:00,  3.48batch/s, loss=0.0646]


Training - Epoch 12/35, Loss: 0.06457997617311775
Validation - Epoch 12/35, Loss: 0.3588961064815521
Validation - Epoch 12/35, Accuracy: 89.47%


Epoch 13/35: 100%|██████████| 50/50 [00:14<00:00,  3.45batch/s, loss=0.0369]


Training - Epoch 13/35, Loss: 0.03692597491433844
Validation - Epoch 13/35, Loss: 0.4226158559322357
Validation - Epoch 13/35, Accuracy: 89.47%


Epoch 14/35: 100%|██████████| 50/50 [00:14<00:00,  3.36batch/s, loss=0.0307]


Training - Epoch 14/35, Loss: 0.03066473395563662
Validation - Epoch 14/35, Loss: 0.4326937198638916
Validation - Epoch 14/35, Accuracy: 89.47%


Epoch 15/35: 100%|██████████| 50/50 [00:14<00:00,  3.41batch/s, loss=0.0405]


Training - Epoch 15/35, Loss: 0.04054987074574456
Validation - Epoch 15/35, Loss: 0.34798556566238403
Validation - Epoch 15/35, Accuracy: 89.47%


Epoch 16/35: 100%|██████████| 50/50 [00:14<00:00,  3.42batch/s, loss=0.0359]


Training - Epoch 16/35, Loss: 0.03587321045808494
Validation - Epoch 16/35, Loss: 0.32558953762054443
Validation - Epoch 16/35, Accuracy: 89.47%


Epoch 17/35: 100%|██████████| 50/50 [00:13<00:00,  3.60batch/s, loss=0.0377]


Training - Epoch 17/35, Loss: 0.03773732263827696
Validation - Epoch 17/35, Loss: 0.30029869079589844
Validation - Epoch 17/35, Accuracy: 89.47%


Epoch 18/35: 100%|██████████| 50/50 [00:13<00:00,  3.59batch/s, loss=0.0356]


Training - Epoch 18/35, Loss: 0.03557224759424571
Validation - Epoch 18/35, Loss: 0.35753005743026733
Validation - Epoch 18/35, Accuracy: 89.47%


Epoch 19/35: 100%|██████████| 50/50 [00:14<00:00,  3.52batch/s, loss=0.0545]


Training - Epoch 19/35, Loss: 0.05447992542642169
Validation - Epoch 19/35, Loss: 0.3151153326034546
Validation - Epoch 19/35, Accuracy: 89.47%


Epoch 20/35: 100%|██████████| 50/50 [00:14<00:00,  3.46batch/s, loss=0.0377]


Training - Epoch 20/35, Loss: 0.037735212070983834
Validation - Epoch 20/35, Loss: 0.34606868028640747
Validation - Epoch 20/35, Accuracy: 89.47%


Epoch 21/35: 100%|██████████| 50/50 [00:14<00:00,  3.44batch/s, loss=0.03]  


Training - Epoch 21/35, Loss: 0.029960242436500265
Validation - Epoch 21/35, Loss: 0.423836350440979
Validation - Epoch 21/35, Accuracy: 89.47%


Epoch 22/35: 100%|██████████| 50/50 [00:14<00:00,  3.40batch/s, loss=0.0271] 


Training - Epoch 22/35, Loss: 0.027087427609367297
Validation - Epoch 22/35, Loss: 0.33514222502708435
Validation - Epoch 22/35, Accuracy: 89.47%


Epoch 23/35: 100%|██████████| 50/50 [00:14<00:00,  3.43batch/s, loss=0.0234]


Training - Epoch 23/35, Loss: 0.02344280424527824
Validation - Epoch 23/35, Loss: 0.2992278039455414
Validation - Epoch 23/35, Accuracy: 89.47%


Epoch 24/35: 100%|██████████| 50/50 [00:14<00:00,  3.46batch/s, loss=0.0211]


Training - Epoch 24/35, Loss: 0.021055261041910852
Validation - Epoch 24/35, Loss: 0.34855756163597107
Validation - Epoch 24/35, Accuracy: 89.47%


Epoch 25/35: 100%|██████████| 50/50 [00:14<00:00,  3.51batch/s, loss=0.0438] 


Training - Epoch 25/35, Loss: 0.04382841961865779
Validation - Epoch 25/35, Loss: 0.24392656981945038
Validation - Epoch 25/35, Accuracy: 89.47%


Epoch 26/35: 100%|██████████| 50/50 [00:14<00:00,  3.35batch/s, loss=0.027] 


Training - Epoch 26/35, Loss: 0.02696623102063313
Validation - Epoch 26/35, Loss: 0.43590033054351807
Validation - Epoch 26/35, Accuracy: 89.47%


Epoch 27/35: 100%|██████████| 50/50 [00:14<00:00,  3.43batch/s, loss=0.0243]


Training - Epoch 27/35, Loss: 0.024306852975860237
Validation - Epoch 27/35, Loss: 0.28790900111198425
Validation - Epoch 27/35, Accuracy: 89.47%


Epoch 28/35: 100%|██████████| 50/50 [00:14<00:00,  3.46batch/s, loss=0.0223]


Training - Epoch 28/35, Loss: 0.022282992944237776
Validation - Epoch 28/35, Loss: 0.42373165488243103
Validation - Epoch 28/35, Accuracy: 89.47%


Epoch 29/35: 100%|██████████| 50/50 [00:14<00:00,  3.35batch/s, loss=0.0173]


Training - Epoch 29/35, Loss: 0.017262013048748484
Validation - Epoch 29/35, Loss: 0.3167019784450531
Validation - Epoch 29/35, Accuracy: 89.47%


Epoch 30/35: 100%|██████████| 50/50 [00:14<00:00,  3.42batch/s, loss=0.0268]


Training - Epoch 30/35, Loss: 0.02675908955250634
Validation - Epoch 30/35, Loss: 0.29010382294654846
Validation - Epoch 30/35, Accuracy: 89.47%


Epoch 31/35: 100%|██████████| 50/50 [00:14<00:00,  3.49batch/s, loss=0.0312]


Training - Epoch 31/35, Loss: 0.031233067314169603
Validation - Epoch 31/35, Loss: 0.3330632448196411
Validation - Epoch 31/35, Accuracy: 89.47%


Epoch 32/35: 100%|██████████| 50/50 [00:14<00:00,  3.42batch/s, loss=0.0139] 


Training - Epoch 32/35, Loss: 0.013896803817842737
Validation - Epoch 32/35, Loss: 0.4039005935192108
Validation - Epoch 32/35, Accuracy: 89.47%


Epoch 33/35: 100%|██████████| 50/50 [00:14<00:00,  3.36batch/s, loss=0.0117]


Training - Epoch 33/35, Loss: 0.011742404209398956
Validation - Epoch 33/35, Loss: 0.42254528403282166
Validation - Epoch 33/35, Accuracy: 89.47%


Epoch 34/35: 100%|██████████| 50/50 [00:14<00:00,  3.46batch/s, loss=0.0257]


Training - Epoch 34/35, Loss: 0.025715493727911962
Validation - Epoch 34/35, Loss: 0.44618356227874756
Validation - Epoch 34/35, Accuracy: 89.47%


Epoch 35/35: 100%|██████████| 50/50 [00:14<00:00,  3.54batch/s, loss=0.0287]


Training - Epoch 35/35, Loss: 0.028678646354237572
Validation - Epoch 35/35, Loss: 0.34118860960006714
Validation - Epoch 35/35, Accuracy: 89.47%
Training completed and modal saved to vit_model.pth


## Testing

## Visualizing Results

In [10]:
# NOTE(review): disabled confusion-matrix heatmap. It reads `conf_matrix` and
# `class_names`, neither of which is defined in any visible cell - presumably
# they come from a test-set evaluation step; confirm before re-enabling.
# plt.figure(figsize=(20, 16))
# sns.set(font_scale=1.5)
# sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
# plt.xlabel('Predicted')
# plt.ylabel('True')
# plt.title('Confusion Matrix')