In [None]:
import torch

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os

In [3]:
img_size =224  #standard ViT size
batch_size =16


In [None]:
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),

    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )

])



In [None]:
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),

    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )

])



In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# data_path = "/content/drive/MyDrive/Combined Dataset/train"
data_path = ""

dataset = datasets.ImageFolder(data_path, transform=train_transform)

num_class = len(dataset.classes)
num_class

4

In [8]:
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)



In [None]:
# test_data_path = "/content/drive/MyDrive/Combined Dataset/test"
test_data_path =""

test_dataset = datasets.ImageFolder(test_data_path, transform=test_transform)

num_class = len(test_dataset.classes)
num_class

4

In [None]:
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4, 
    pin_memory=True
)




In [9]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels =3, embed_dim = 768):
        super().__init__()

        self.patch_size = patch_size
        self.embed_dim =embed_dim
        self.num_patches = (img_size // patch_size) **2

        self.projection = nn.Conv2d(
            in_channels = in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # x=[B,C,H,W]
        x= self.projection(x)
        # [B, embed_dim, H/p, W/p]

        x= x.flatten(2)
        # [B, embed_dim, num_patches]

        x= x.transpose(1,2)
        # [B, num_patches, embed_dim]

        return x

In [10]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim//num_heads

        self.qkv = nn.Linear(embed_dim,embed_dim*3)
        self.fc_out= nn.Linear(embed_dim,embed_dim)


    def forward(self, x):
        # B = 16(batch_size) N = 197(head_dim) D = 768(embed_dim)
        b,n,d =x.shape

        qkv = self.qkv(x) #[B, N, 3D]

        qkv = qkv.reshape(b,n, 3, self.num_heads, self.head_dim) #[B, N, 3, H, Hd]

        qkv = qkv.permute(2,0,3,1,4)# [3, B, H, N, Hd]


        q,k,v = qkv[0], qkv[1], qkv[2] #each [B, H, N, Hd]

        # Q: [B, H, N, Hd]
        # Kᵀ: [B, H, Hd, N]
        # attention: [B, H, N, N]
        attention = (q @ k.transpose(-2,-1)) / np.sqrt(self.head_dim)

        attention = attention.softmax(dim =-1)

        #[B, H, N, N] @ [B, H, N, Hd] → [B, H, N, Hd]
        out = attention @ v
        # [B, N, D]
        out = out.transpose(1,2).reshape(b,n,d)
        return self.fc_out(out)

In [None]:
class TransfromerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio =4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim,embed_dim*mlp_ratio),
            nn.GELU(),
            nn.Linear(embed_dim*mlp_ratio, embed_dim)
        )

    def forward(self, x):
        x = x+ self.attn(self.norm1(x))
        x = x+ self.mlp(self.norm2(x))
        return x


In [None]:
class VisionTransfromer(nn.Module):
    def __init__(self,
        img_size=224,
        patch_size = 16,
        in_channels=3,
        num_classes = 4,
        embed_dim =768,
        depth =6, #6
        num_heads = 8
    ):
        super().__init__()
        self.patch_embed =PatchEmbedding(img_size,patch_size,in_channels,embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1,1+self.patch_embed.num_patches, embed_dim)
        )

        self.encoder = nn.Sequential(
            *[TransfromerEncoder(embed_dim,num_heads) for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B=x.shape[0]

        x= self.patch_embed(x)
        cls_token =self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_token, x), dim=1)
        x = x+self.pos_embed
        x= self.encoder(x)
        x = self.norm(x)

        cls_output = x[:,0]
        return self.head(cls_output)



In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# print(device)

model = VisionTransfromer(
    num_classes=num_class
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(),lr=3e-4)


In [None]:
EPOCHS = 10

for epoch in range(EPOCHS):
    model.train()

    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)             
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        _, preds = torch.max(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / len(train_loader)
    train_acc = 100.0 * correct / total

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] "
        f"Loss: {avg_loss:.4f} "
        f"Train Acc: {train_acc:.2f}%"
    )


Epoch [1/10] Loss: 1.4744 Train Acc: 24.76%
Epoch [2/10] Loss: 1.4035 Train Acc: 24.92%
Epoch [3/10] Loss: 1.3970 Train Acc: 25.12%
Epoch [4/10] Loss: 1.3948 Train Acc: 24.49%
Epoch [5/10] Loss: 1.3913 Train Acc: 24.83%
Epoch [6/10] Loss: 1.3903 Train Acc: 25.30%
Epoch [7/10] Loss: 1.3904 Train Acc: 24.37%
Epoch [8/10] Loss: 1.3888 Train Acc: 25.05%
Epoch [9/10] Loss: 1.3888 Train Acc: 24.76%
Epoch [10/10] Loss: 1.3887 Train Acc: 25.01%


In [15]:
torch.save(model.state_dict(), "vit_weights.pth")


In [16]:
torch.save(model, "vit_full_model.pth")

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd

def run_evaluation(model, loader, device, class_names):
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print("\n--- Classification Report ---")
    print(classification_report(all_labels, all_preds, target_names=class_names))

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.show()

run_evaluation(model, train_loader, device, dataset.classes)

In [None]:
model = VisionTransfromer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=4, 
    embed_dim=768,
    depth=4,
    num_heads=8
).to(device)

model.load_state_dict(torch.load("vit_weights.pth", map_location=device))
model.eval()

VisionTransfromer(
  (patch_embed): PatchEmbedding(
    (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (encoder): Sequential(
    (0): TransfromerEncoder(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadSelfAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (fc_out): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): Linear(in_features=768, out_features=768, bias=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
    (1): TransfromerEncoder(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadSelfAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (fc_out): Linear(in_features=768, out_features=768, bias=True)
      )


In [None]:
from PIL import Image

def predict_image(image_path, model, transform, device, class_names):
    img = Image.open(image_path).convert("RGB")

    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        conf, pred = torch.max(probabilities, 1)

    class_idx = pred.item()
    print(f"Prediction: {class_names[class_idx]} ({conf.item()*100:.2f}%)")

predict_image("/content/drive/MyDrive/Combined Dataset/test/Mild Impairment/1 (10).jpg", model, train_transform, device, dataset.classes)

Prediction: Moderate Impairment (26.66%)
