In [1]:
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader
import os
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Dataset root. Overridable via the OASIS_DATA_DIR environment variable so the
# notebook is not tied to one machine; the original absolute path is the fallback.
DATA_DIR = os.environ.get("OASIS_DATA_DIR", "/Users/rishavghosh/Desktop/python/Oasis_dataset")

In [3]:
# Pretrained ImageNet weights for AlexNet; the same enum is reused when the model is built.
weights=models.AlexNet_Weights.IMAGENET1K_V1
# The weights' bundled preprocessing preset (resize 256, center-crop 224, ImageNet
# mean/std normalization, as shown in the printed repr). It is applied to every
# image in the dataset, training images included, so no data augmentation is done.
transform = weights.transforms()

In [4]:
# Display the preprocessing preset so its resize/crop/normalization settings are visible.
transform

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

In [5]:
# One class per sub-directory of DATA_DIR; label index i corresponds to classes[i].
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
print("Classes:", full_dataset.classes)

Classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']


In [6]:
# Reproducible 80/20 split: a fixed-seed generator makes random_split return the
# same partition on every Restart & Run All (the global RNG is never seeded here).
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

train_size = int(TRAIN_FRACTION * len(full_dataset))
test_size = len(full_dataset) - train_size

# NOTE(review): this split is per image. If several images come from the same
# subject (e.g. multiple MRI slices per scan), one subject can land in both
# splits and inflate test accuracy; consider a subject-grouped split.
# TODO confirm how file names map to subjects.
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print("Train size:", len(train_dataset))
print("Test size:", len(test_dataset))

Train size: 96336
Test size: 24085


In [7]:
from torch.utils.data import WeightedRandomSampler
import numpy as np
import torch

# Class index of every training image: ImageFolder stores (path, class_idx)
# pairs in .samples, and the Subset from random_split exposes the chosen .indices.
train_targets = [full_dataset.samples[i][1] for i in train_dataset.indices]
# minlength keeps one count per class even if a class is absent from the split,
# so the zip() below cannot silently drop or misalign class names.
class_counts = np.bincount(train_targets, minlength=len(full_dataset.classes))

# .tolist() prints plain ints (NumPy 2 would otherwise show np.int64(...)).
print("Train class counts:", dict(zip(full_dataset.classes, class_counts.tolist())))

# Inverse-frequency class weights: every class gets the same total sampling mass.
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
# Vectorized lookup gives one weight per training sample as a single 1-D tensor,
# instead of a Python list of ~10^5 zero-dim tensors.
sample_weights = class_weights[torch.as_tensor(train_targets, dtype=torch.long)]

# Sampler for balanced training: draws len(train set) indices with replacement.
train_sampler = WeightedRandomSampler(weights=sample_weights,
                                      num_samples=len(sample_weights),
                                      replacement=True)


Train class counts: {'Mild Dementia': 11212, 'Moderate Dementia': 5543, 'Non Demented': 61488, 'Very mild Dementia': 18093}


In [8]:
BATCH_SIZE = 32

# Training batches come from the class-balancing weighted sampler (a sampler and
# shuffle are mutually exclusive); the test set is read in its fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [9]:
# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU, so the
# notebook also runs unchanged on non-Mac machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [10]:
# AlexNet initialised with the same ImageNet weights whose preset transform is used.
model = models.alexnet(weights=weights)
# Replace the classifier's last layer with one emitting a logit per dataset class.
# The size comes from ImageFolder instead of a hard-coded 4, so adding or removing
# a class folder cannot silently mismatch the head.
num_classes = len(full_dataset.classes)
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
model = model.to(device)


In [11]:
from torchinfo import summary

# Layer-by-layer shapes, parameter counts and trainability for a batch of
# 32 RGB images at 224x224.
summary(
    model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
AlexNet (AlexNet)                        [32, 3, 224, 224]    [32, 4]              --                   True
├─Sequential (features)                  [32, 3, 224, 224]    [32, 256, 6, 6]      --                   True
│    └─Conv2d (0)                        [32, 3, 224, 224]    [32, 64, 55, 55]     23,296               True
│    └─ReLU (1)                          [32, 64, 55, 55]     [32, 64, 55, 55]     --                   --
│    └─MaxPool2d (2)                     [32, 64, 55, 55]     [32, 64, 27, 27]     --                   --
│    └─Conv2d (3)                        [32, 64, 27, 27]     [32, 192, 27, 27]    307,392              True
│    └─ReLU (4)                          [32, 192, 27, 27]    [32, 192, 27, 27]    --                   --
│    └─MaxPool2d (5)                     [32, 192, 27, 27]    [32, 192, 13, 13]    --                   --
│    └─Conv2d (6)     

In [12]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Shared by ``train`` and ``evaluate``. When ``optimizer`` is given, the model
    is put in training mode and updated after every batch. Otherwise it is put
    in eval mode and autograd is disabled.

    The loss is a per-sample mean: each batch loss is re-weighted by the batch
    size. This assumes ``criterion`` averages over the batch (the default
    ``reduction='mean'``, as with ``nn.CrossEntropyLoss()``). In training,
    accuracy uses the logits produced *before* the optimizer step.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of dividing by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch loss/accuracy")
    return running_loss / total, correct / total


def train(model, loader, criterion, optimizer, device):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: network whose parameters ``optimizer`` updates in place.
        loader: iterable of ``(inputs, labels)`` batches; labels are class indices.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimizer over the parameters to train.
        device: device each batch is moved to (the model should already be there).

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and fraction correct.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=optimizer)


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` without updating it.

    Runs in eval mode with gradients disabled.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, labels)`` batches; labels are class indices.
        criterion: loss taking ``(logits, labels)``.
        device: device each batch is moved to (the model should already be there).

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and fraction correct.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    return _run_epoch(model, loader, criterion, device)

In [13]:
# Transfer learning: freeze the pretrained convolutional feature extractor and
# train only the classifier head. Module.requires_grad_ sets the flag on every
# parameter of the sub-module.
model.features.requires_grad_(False)
model.classifier.requires_grad_(True)
model.to(device)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [14]:
LEARNING_RATE = 1e-3

# Cross-entropy on the classifier's logits. Only model.classifier's parameters
# are handed to Adam, so the convolutional feature weights are never updated.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=LEARNING_RATE)

In [15]:
import time

num_epochs = 10
# Per-epoch metrics, kept so learning curves can be plotted after training.
history = []

for epoch in range(num_epochs):
    # perf_counter is monotonic, unlike time.time(), so system clock adjustments
    # cannot distort the measured epoch duration.
    start_time = time.perf_counter()

    # Train metrics are measured on class-balanced batches drawn by the weighted
    # sampler. The "Val" metrics come from the held-out test split in its original
    # class distribution, so the two accuracies are not directly comparable.
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    elapsed = time.perf_counter() - start_time
    history.append({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "seconds": elapsed,
    })
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {elapsed:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss  : {val_loss:.4f}, Val Acc  : {val_acc:.4f}")

Epoch 1/10, Time: 666.53s
  Train Loss: 0.5766, Train Acc: 0.7584
  Val Loss  : 0.4024, Val Acc  : 0.8117
Epoch 2/10, Time: 670.48s
  Train Loss: 0.3847, Train Acc: 0.8446
  Val Loss  : 0.3010, Val Acc  : 0.8659
Epoch 3/10, Time: 700.27s
  Train Loss: 0.3085, Train Acc: 0.8790
  Val Loss  : 0.2618, Val Acc  : 0.8857
Epoch 4/10, Time: 726.02s
  Train Loss: 0.2689, Train Acc: 0.8967
  Val Loss  : 0.2403, Val Acc  : 0.8929
Epoch 5/10, Time: 724.40s
  Train Loss: 0.2403, Train Acc: 0.9083
  Val Loss  : 0.2298, Val Acc  : 0.9028
Epoch 6/10, Time: 712.95s
  Train Loss: 0.2149, Train Acc: 0.9190
  Val Loss  : 0.2140, Val Acc  : 0.9048
Epoch 7/10, Time: 713.35s
  Train Loss: 0.2014, Train Acc: 0.9255
  Val Loss  : 0.1448, Val Acc  : 0.9354
Epoch 8/10, Time: 714.61s
  Train Loss: 0.1937, Train Acc: 0.9283
  Val Loss  : 0.1430, Val Acc  : 0.9367
Epoch 9/10, Time: 697.45s
  Train Loss: 0.1768, Train Acc: 0.9355
  Val Loss  : 0.1607, Val Acc  : 0.9293
Epoch 10/10, Time: 720.76s
  Train Loss: 0.170

In [16]:
# Serialize the weights from the CPU. Tensors are saved together with their device,
# so a checkpoint written straight from MPS may need map_location to load on a
# machine without MPS. The model is moved back afterwards so it remains usable here.
model = model.to("cpu")
torch.save(model.state_dict(), 'alexNet_alzheimers_extraData.pth')
model = model.to(device)