imports

In [1]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset ,DataLoader, WeightedRandomSampler
from torch.utils.data.dataset import Subset
import os
from PIL import Image

from torchvision.transforms import ToTensor

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fdbb2b9e0b0>

In [97]:
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, num_layers, hidden_dim, num_heads, mlp_dim, dropout_rate=0.1):
        super(VisionTransformer, self).__init__()
        
        num_patches_height = (image_size - patch_size) // patch_size + 1
        num_patches_width = (image_size - patch_size) // patch_size + 1
        num_patches = num_patches_height * num_patches_width

        patch_dim = 256  # Adjust patch dimension to match positional embedding dimension
        # num_patches = (image_size // patch_size) ** 2
        # patch_dim = patch_size ** 2  # Assuming input images are RGB
        
        # Patch embedding layer
        self.patch_embedding = nn.Conv2d(in_channels=1, out_channels=patch_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, patch_dim))
        
        # Transformer Encoder
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=patch_dim, nhead=num_heads), num_layers)
        
        # Classification head
        # self.cls_token = nn.Parameter(torch.randn(32, patch_dim, 1, 1))
        self.cls_token = nn.Parameter(torch.randn(1,1,patch_dim))
        self.mlp_head = nn.Linear(patch_dim, num_classes)
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        patches = self.patch_embedding(x)
        batch_size = patches.size(0)
        patch_size = patches.size(2)
        
        print("Shape of patches:", patches.shape)
        print("Shape of positional embedding:", self.positional_embedding.shape)
        # Calculate the number of patches
        num_patches_height = patches.size(2)
        num_patches_width = patches.size(3)
        num_patches = num_patches_height * num_patches_width
        
        # Add positional embeddings
        patches = patches + self.positional_embedding[:, :num_patches_height, :num_patches_width]
        
        # # Add positional embeddings old
        # patches = patches + self.positional_embedding[:, :patch_size, :]
        
        # # # Concatenate classification token to patches
        cls_token = self.cls_token.expand(batch_size, -1, -1, -1)
        print(patches.shape)
        print(cls_token.shape)
        
        patches = torch.cat([cls_token, patches], dim=1)
        
        # Reshape patches for transformer input
        patches = patches.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        
        # Transformer Encoder
        encoder_output = self.encoder(patches)
        
        # Extract classification token representation
        cls_token = encoder_output[0]  # First token is the classification token
        
        # Classification head
        cls_token = self.dropout(cls_token)
        output = self.mlp_head(cls_token)
        
        
        return output


define dataset

In [98]:
class AlzheimerDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the image categories.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # List the categories
        categories = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
        label_mapping = {category: idx for idx, category in enumerate(categories)}

        for category in categories:
            category_path = os.path.join(root_dir, category)
            for img_name in os.listdir(category_path):
                self.image_paths.append(os.path.join(category_path, img_name))
                self.labels.append(label_mapping[category])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # Convert to grayscale if not already

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        # Convert label to tensor
        label_tensor = torch.tensor(label, dtype=torch.long)

        return image, label_tensor

In [99]:
# Define transformations for data augmentation
from torch.utils.data import WeightedRandomSampler, random_split
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Create an instance of the dataset
dataset = AlzheimerDataset(root_dir='./Data', transform=transform)

# Calculate class weights for balancing
class_sample_count = np.array([len(np.where(np.array(dataset.labels) == t)[0]) for t in np.unique(dataset.labels)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in dataset.labels])

# Define the sizes for train, validation, and test sets
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Extract the indices of the train_dataset to get the corresponding weights
train_indices = train_dataset.indices
train_weights = samples_weight[train_indices]

# Create the sampler for the training set
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))

# Number of workers for data loading
num_workers = 0  # Adjust based on your system’s capability

# Create the DataLoaders with the sampler for the training set
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=train_sampler,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True  # Helps with faster data transfer to CUDA
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)


show data shapes in 1 batch

In [100]:
# Fetch one mini-batch from the DataLoader
for images, labels in train_loader:
    print(f'Batch images shape: {images.shape}')  # Should be [batch_size, 1, height, width]
    print(f'Batch labels shape: {labels.shape}')  # Should be [batch_size]

    # Now iterate through the mini-batch to check each image and label
    for i in range(images.size(0)):  # Loop through the batch
        image_shape = images[i].shape
        label = labels[i]
        print(f'Image {i+1} shape: {image_shape}, Label: {label}')

    # Break after one batch to limit output
    break

Batch images shape: torch.Size([32, 1, 248, 496])
Batch labels shape: torch.Size([32])
Image 1 shape: torch.Size([1, 248, 496]), Label: 3
Image 2 shape: torch.Size([1, 248, 496]), Label: 2
Image 3 shape: torch.Size([1, 248, 496]), Label: 1
Image 4 shape: torch.Size([1, 248, 496]), Label: 3
Image 5 shape: torch.Size([1, 248, 496]), Label: 2
Image 6 shape: torch.Size([1, 248, 496]), Label: 2
Image 7 shape: torch.Size([1, 248, 496]), Label: 0
Image 8 shape: torch.Size([1, 248, 496]), Label: 1
Image 9 shape: torch.Size([1, 248, 496]), Label: 2
Image 10 shape: torch.Size([1, 248, 496]), Label: 2
Image 11 shape: torch.Size([1, 248, 496]), Label: 3
Image 12 shape: torch.Size([1, 248, 496]), Label: 2
Image 13 shape: torch.Size([1, 248, 496]), Label: 0
Image 14 shape: torch.Size([1, 248, 496]), Label: 1
Image 15 shape: torch.Size([1, 248, 496]), Label: 0
Image 16 shape: torch.Size([1, 248, 496]), Label: 3
Image 17 shape: torch.Size([1, 248, 496]), Label: 2
Image 18 shape: torch.Size([1, 248, 49

train

In [101]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define your Vision Transformer model
image_size = 224
patch_size = 16
num_classes = 4
num_layers = 6
hidden_dim = 256
num_heads = 8
mlp_dim = 512
dropout_rate = 0.1
vit_model = VisionTransformer(image_size, patch_size, num_classes, num_layers, hidden_dim, num_heads, mlp_dim, dropout_rate)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=0.001)

# Define directory for model checkpoints
checkpoint_dir = 'ModelCheckpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    vit_model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = vit_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_dataset)
    
    # Validation phase
    vit_model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = vit_model(inputs)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * inputs.size(0)
    epoch_val_loss = running_val_loss / len(val_dataset)

    # Print epoch training and validation losses
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')

    # Save model checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f'vit_model_epoch_{epoch+1}.pt')
    torch.save(vit_model.state_dict(), checkpoint_path)
    print(f'Model checkpoint saved at {checkpoint_path}')

print('Finished Training')

Shape of patches: torch.Size([32, 256, 15, 31])
Shape of positional embedding: torch.Size([1, 197, 256])
torch.Size([32, 256, 15, 31])
torch.Size([32, 1, 1, 256])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 15 for tensor number 1 in the list.