In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms


In [2]:
DATA_DIR = "/kaggle/input/oasis-dataset/Oasis_dataset"

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else
                      ("mps" if torch.backends.mps.is_available() else "cpu"))

In [4]:
weights = models.ResNet152_Weights.IMAGENET1K_V1
model = models.resnet152(weights=weights)
transform = weights.transforms()

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 215MB/s] 


In [5]:
full_dataset = datasets.ImageFolder(DATA_DIR, transform=transform)

In [6]:
from torch.utils.data import random_split, DataLoader
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

print("Train size:", len(train_dataset))
print("Test size:", len(test_dataset))

Train size: 96336
Test size: 24085


In [7]:
from torch.utils.data import WeightedRandomSampler
import numpy as np
import torch

# Extract labels from train dataset
train_targets = [full_dataset.samples[i][1] for i in train_dataset.indices]
class_counts = np.bincount(train_targets)

print("Train class counts:", dict(zip(full_dataset.classes, class_counts)))

# Compute class weights
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = [class_weights[label] for label in train_targets]

# Sampler for balanced training
train_sampler = WeightedRandomSampler(weights=sample_weights,
                                      num_samples=len(sample_weights),
                                      replacement=True)


Train class counts: {'Mild Dementia': 11136, 'Moderate Dementia': 5534, 'Non Demented': 61579, 'Very mild Dementia': 18087}


In [8]:
# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [9]:
num_classes = 4   # change this to match your dataset
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

In [10]:
for param in model.parameters():
    param.requires_grad = False   # freeze all layers

# Unfreeze last few layers + final fc
for name, param in list(model.named_parameters())[-20:]:  # last ~20 params ≈ last 4-5 layers
    param.requires_grad = True

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
model.to(DEVICE)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
from torchinfo import summary
summary(model = model,
        input_size = (32, 3, 224, 224),
        col_names = ["input_size", "output_size", "num_params", "trainable"],
        col_width = 20,
        row_settings = ["var_names"]
       )

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ResNet (ResNet)                          [32, 3, 224, 224]    [32, 4]              --                   Partial
├─Conv2d (conv1)                         [32, 3, 224, 224]    [32, 64, 112, 112]   (9,408)              False
├─BatchNorm2d (bn1)                      [32, 64, 112, 112]   [32, 64, 112, 112]   (128)                False
├─ReLU (relu)                            [32, 64, 112, 112]   [32, 64, 112, 112]   --                   --
├─MaxPool2d (maxpool)                    [32, 64, 112, 112]   [32, 64, 56, 56]     --                   --
├─Sequential (layer1)                    [32, 64, 56, 56]     [32, 256, 56, 56]    --                   False
│    └─Bottleneck (0)                    [32, 64, 56, 56]     [32, 256, 56, 56]    --                   False
│    │    └─Conv2d (conv1)               [32, 64, 56, 56]     [32, 64, 56, 56]     (4,096)              False
│    │    

In [13]:
def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

# Evaluation function
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [14]:
import time
num_epochs = 5

for epoch in range(num_epochs):
    start_time = time.time()

    train_loss, train_acc = train(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_acc = evaluate(model, test_loader, criterion, DEVICE)

    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1}/{num_epochs}, Time: {elapsed:.2f}s")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss  : {val_loss:.4f}, Val Acc  : {val_acc:.4f}")

Epoch 1/5, Time: 1264.83s
  Train Loss: 0.3485, Train Acc: 0.8594
  Val Loss  : 0.2621, Val Acc  : 0.8933
Epoch 2/5, Time: 1072.54s
  Train Loss: 0.1410, Train Acc: 0.9477
  Val Loss  : 0.1811, Val Acc  : 0.9290
Epoch 3/5, Time: 1042.65s
  Train Loss: 0.0884, Train Acc: 0.9684
  Val Loss  : 0.1307, Val Acc  : 0.9498
Epoch 4/5, Time: 1009.97s
  Train Loss: 0.0637, Train Acc: 0.9771
  Val Loss  : 0.1283, Val Acc  : 0.9570
Epoch 5/5, Time: 1002.16s
  Train Loss: 0.0494, Train Acc: 0.9822
  Val Loss  : 0.1372, Val Acc  : 0.9556


In [15]:
torch.save(model.state_dict(), 'resnet152_extraData.pth')