In [1]:
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader
import os
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
DATA_DIR = "/Users/rishavghosh/Desktop/python/Oasis_dataset"

In [3]:
weights = models.Inception_V3_Weights.DEFAULT
model = models.inception_v3(weights=weights, aux_logits=True)
transform = weights.transforms()

In [4]:
full_dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
print("Classes:", full_dataset.classes)

Classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']


In [5]:
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

print("Train size:", len(train_dataset))
print("Test size:", len(test_dataset))

Train size: 96336
Test size: 24085


In [6]:
from torch.utils.data import WeightedRandomSampler
import numpy as np
import torch

# Extract labels from train dataset
train_targets = [full_dataset.samples[i][1] for i in train_dataset.indices]
class_counts = np.bincount(train_targets)

print("Train class counts:", dict(zip(full_dataset.classes, class_counts)))

# Compute class weights
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = [class_weights[label] for label in train_targets]

# Sampler for balanced training
train_sampler = WeightedRandomSampler(weights=sample_weights,
                                      num_samples=len(sample_weights),
                                      replacement=True)


Train class counts: {'Mild Dementia': 11191, 'Moderate Dementia': 5503, 'Non Demented': 61425, 'Very mild Dementia': 18217}


In [7]:
# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [8]:
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [9]:
model.fc = nn.Linear(model.fc.in_features, 4)
if model.AuxLogits is not None:
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 4)

model = model.to(DEVICE)
print("Model device:", next(model.parameters()).device)

Model device: mps:0


In [10]:
criterion = nn.CrossEntropyLoss()

In [11]:
from torchinfo import summary
summary(model = model,
        input_size = (32, 3, 224, 224),
        col_names = ["input_size", "output_size", "num_params", "trainable"],
        col_width = 20,
        row_settings = ["var_names"]
       )

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Inception3 (Inception3)                  [32, 3, 224, 224]    [32, 4]              2,560,772            True
├─BasicConv2d (Conv2d_1a_3x3)            [32, 3, 224, 224]    [32, 32, 111, 111]   --                   True
│    └─Conv2d (conv)                     [32, 3, 224, 224]    [32, 32, 111, 111]   864                  True
│    └─BatchNorm2d (bn)                  [32, 32, 111, 111]   [32, 32, 111, 111]   64                   True
├─BasicConv2d (Conv2d_2a_3x3)            [32, 32, 111, 111]   [32, 32, 109, 109]   --                   True
│    └─Conv2d (conv)                     [32, 32, 111, 111]   [32, 32, 109, 109]   9,216                True
│    └─BatchNorm2d (bn)                  [32, 32, 109, 109]   [32, 32, 109, 109]   64                   True
├─BasicConv2d (Conv2d_2b_3x3)            [32, 32, 109, 109]   [32, 64, 109, 109]   --                   True
│    └─Conv2d 

In [12]:
def train_one_epoch(loader, optimizer):
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        out = model(x)
        if hasattr(out, "logits"):  # InceptionOutputs
            main_logits, aux_logits = out.logits, out.aux_logits
            loss = criterion(main_logits, y)
            if aux_logits is not None:
                loss += 0.4 * criterion(aux_logits, y)
            preds = main_logits.argmax(1)
        else:  # Other models
            loss = criterion(out, y)
            preds = out.argmax(1)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        correct += (preds == y).sum().item()
        total += y.size(0)
    return loss_sum / total, correct / total

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        logits = out.logits if hasattr(out, "logits") else out
        loss = criterion(logits, y)
        preds = logits.argmax(1)
        loss_sum += loss.item() * x.size(0)
        correct += (preds == y).sum().item()
        total += y.size(0)
    return loss_sum / total, correct / total



In [13]:
# --------------------------
# Phase 1: Head-only training
# --------------------------
for p in model.parameters():
    p.requires_grad = False
for p in model.fc.parameters():
    p.requires_grad = True
if model.AuxLogits is not None:
    for p in model.AuxLogits.parameters():
        p.requires_grad = True

In [16]:
model = model.to(DEVICE)

In [17]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

In [18]:
import time
for epoch in range(1, 6):
    t0 = time.time()
    tr_loss, tr_acc = train_one_epoch(train_loader, optimizer)
    val_loss, val_acc = evaluate(test_loader)
    dt = (time.time() - t0) / 60
    print(f"Epoch {epoch:02d} | {dt:.1f} min  "
          f"train {tr_loss:.4f}/{tr_acc*100:.2f}%  "
          f"val {val_loss:.4f}/{val_acc*100:.2f}%")

Epoch 01 | 45.6 min  train 1.1466/57.16%  val 0.9099/61.93%
Epoch 02 | 38.6 min  train 1.0481/58.42%  val 0.7798/68.31%
Epoch 03 | 52.2 min  train 1.0225/58.46%  val 0.8468/66.16%
Epoch 04 | 38.2 min  train 1.0138/58.38%  val 0.8763/64.36%
Epoch 05 | 39.0 min  train 1.0094/58.23%  val 0.7882/68.01%


In [19]:
torch.save(model.state_dict(), 'inceptionV3_alzheimers_extraData.pth')