<a href="https://colab.research.google.com/github/ssudhanshu488/SwinOnAlziehmer/blob/main/SwinTransformerOnAlziehmer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision timm pandas scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import timm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [18]:
# Define the dataset class
class AlzheimerDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [19]:
# Load the dataset
def load_dataset(image_folder):
    image_paths = []
    labels = []
    for image_name in os.listdir(image_folder):
        image_path = os.path.join(image_folder, image_name)
        label = image_name.split('_')[0]  # Assuming the class is the first part of the filename
        image_paths.append(image_path)
        labels.append(label)
    return image_paths, labels

In [20]:
# Preprocess the dataset
def preprocess_dataset(image_paths, labels):
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(labels)
    train_paths, val_paths, train_labels, val_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)
    return train_paths, val_paths, train_labels, val_labels, label_encoder

In [21]:
# Define the transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [22]:
# Load the dataset
image_folder = '/content/All_img_diff_name'
image_paths, labels = load_dataset(image_folder)
train_paths, val_paths, train_labels, val_labels, label_encoder = preprocess_dataset(image_paths, labels)

In [8]:
!unzip All_img_diff_name.zip

Archive:  All_img_diff_name.zip
   creating: All_img_diff_name/
  inflating: All_img_diff_name/AD_1.jpg  
  inflating: All_img_diff_name/AD_10.jpg  
  inflating: All_img_diff_name/AD_100.jpg  
  inflating: All_img_diff_name/AD_1000.jpg  
  inflating: All_img_diff_name/AD_1001.jpg  
  inflating: All_img_diff_name/AD_1002.jpg  
  inflating: All_img_diff_name/AD_1003.jpg  
  inflating: All_img_diff_name/AD_1004.jpg  
  inflating: All_img_diff_name/AD_1005.jpg  
  inflating: All_img_diff_name/AD_1006.jpg  
  inflating: All_img_diff_name/AD_1007.jpg  
  inflating: All_img_diff_name/AD_1008.jpg  
  inflating: All_img_diff_name/AD_1009.jpg  
  inflating: All_img_diff_name/AD_101.jpg  
  inflating: All_img_diff_name/AD_1010.jpg  
  inflating: All_img_diff_name/AD_1011.jpg  
  inflating: All_img_diff_name/AD_1012.jpg  
  inflating: All_img_diff_name/AD_1013.jpg  
  inflating: All_img_diff_name/AD_1014.jpg  
  inflating: All_img_diff_name/AD_1015.jpg  
  inflating: All_img_diff_name/AD_1016.jpg 

In [23]:
# Create the datasets
train_dataset = AlzheimerDataset(train_paths, train_labels, transform=transform)
val_dataset = AlzheimerDataset(val_paths, val_labels, transform=transform)

In [24]:
# Create the dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [25]:
# Load the pre-trained Swin Transformer model
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=3)

In [26]:
# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [27]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [28]:
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Function to compute metrics
def compute_metrics(true_labels, predicted_labels):
    accuracy = accuracy_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels, average='weighted')  # Use 'weighted' for multi-class
    precision = precision_score(true_labels, predicted_labels, average='weighted')  # Use 'weighted' for multi-class
    return accuracy, f1, precision

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

    # Validation loop after each epoch
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    accuracy, f1, precision = compute_metrics(all_labels, all_preds)
    print(f'Validation Metrics after Epoch {epoch+1}:')
    print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, Precision: {precision:.4f}')

Epoch [1/10], Loss: 0.9709
Validation Metrics after Epoch 1:
Accuracy: 0.6413, F1 Score: 0.6247, Precision: 0.6521
Epoch [2/10], Loss: 0.6918
Validation Metrics after Epoch 2:
Accuracy: 0.7248, F1 Score: 0.6964, Precision: 0.7664
Epoch [3/10], Loss: 0.4312
Validation Metrics after Epoch 3:
Accuracy: 0.7559, F1 Score: 0.7545, Precision: 0.7583
Epoch [4/10], Loss: 0.2636
Validation Metrics after Epoch 4:
Accuracy: 0.8107, F1 Score: 0.8123, Precision: 0.8253
Epoch [5/10], Loss: 0.1182
Validation Metrics after Epoch 5:
Accuracy: 0.9078, F1 Score: 0.9077, Precision: 0.9128
Epoch [6/10], Loss: 0.0927
Validation Metrics after Epoch 6:
Accuracy: 0.9153, F1 Score: 0.9153, Precision: 0.9153
Epoch [7/10], Loss: 0.0477
Validation Metrics after Epoch 7:
Accuracy: 0.9203, F1 Score: 0.9198, Precision: 0.9218
Epoch [8/10], Loss: 0.0329
Validation Metrics after Epoch 8:
Accuracy: 0.7621, F1 Score: 0.7546, Precision: 0.8448
Epoch [9/10], Loss: 0.0751
Validation Metrics after Epoch 9:
Accuracy: 0.9253, F

In [29]:

# Final validation after training
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


In [30]:
# Compute final metrics
accuracy, f1, precision = compute_metrics(all_labels, all_preds)
print('Final Validation Metrics:')
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, Precision: {precision:.4f}')


Final Validation Metrics:
Accuracy: 0.9191, F1 Score: 0.9189, Precision: 0.9210


In [31]:
# Save the model
torch.save(model.state_dict(), 'swin_transformer_alzheimer.pth')