In [1]:
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import os
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
DATA_DIR = "/Users/rishavghosh/Desktop/python/Oasis_dataset"

In [3]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [4]:
full_dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
print("Classes:", full_dataset.classes)

Classes: ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']


In [5]:
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

print("Train size:", len(train_dataset))
print("Test size:", len(test_dataset))

Train size: 69149
Test size: 17288


In [8]:
import numpy as np
train_targets = [full_dataset.samples[i][1] for i in train_dataset.indices]
train_targets
class_counts = np.bincount(train_targets)
class_counts

array([ 3988,   380, 53841, 10940])

In [None]:
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
print("Class weights:", class_weights)
sample_weights = [class_weights[label] for label in train_targets]

In [6]:
from torch.utils.data import WeightedRandomSampler
import numpy as np
import torch

# Extract labels from train dataset
train_targets = [full_dataset.samples[i][1] for i in train_dataset.indices]
class_counts = np.bincount(train_targets)

print("Train class counts:", dict(zip(full_dataset.classes, class_counts)))

# Compute class weights
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = [class_weights[label] for label in train_targets]

# Sampler for balanced training
train_sampler = WeightedRandomSampler(weights=sample_weights,
                                      num_samples=len(sample_weights),
                                      replacement=True)


Train class counts: {'Mild Dementia': 4028, 'Moderate Dementia': 389, 'Non Demented': 53759, 'Very mild Dementia': 10973}


In [7]:
# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [8]:
from collections import Counter

labels_list = []
for _, labels in train_loader:
    labels_list.extend(labels.tolist())

print("Sampled counts:", Counter(labels_list))


Sampled counts: Counter({1: 17301, 3: 17299, 0: 17277, 2: 17272})


In [13]:
class BrainCNN(nn.Module):
    def __init__(self, num_classes):
        super(BrainCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 256),  
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [17]:
import requests
from pathlib import Path 

# Download helper functions from Learn PyTorch repo (if not already downloaded)
if Path("helper_functions.py").is_file():
  print("helper_functions.py already exists, skipping download")
else:
  print("Downloading helper_functions.py")
  # Note: you need the "raw" GitHub URL for this to work
  request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")
  with open("helper_functions.py", "wb") as f:
    f.write(request.content)




Downloading helper_functions.py


In [21]:
import torch.optim as optim
from helper_functions import accuracy_fn

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model_1 = BrainCNN(num_classes=len(full_dataset.classes)).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_1.parameters(), lr=0.0005)


In [15]:
from timeit import default_timer as timer
def print_train_time(start: float, end: float, device: torch.device = None):
    """Prints the training time.

    Args:
        start (float): Time when training started.
        end (float): Time when training ended.
        device (torch.device, optional): Device used for training. Defaults to None.
    """
    total_time = end - start
    print(f"Training time on {device}: {total_time:.3f} seconds")
    return total_time

In [29]:
def train_step(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device = device, accuracy_fn = accuracy_fn):
    train_loss, train_acc = 0, 0
    model.to(device)
    for batch, (X, y) in enumerate(data_loader):
        X, y = X.to(device), y.to(device)
        y_pred = model(X)

        loss = loss_fn(y_pred, y)
        train_loss += loss
        train_acc += accuracy_fn(y_true = y, y_pred = y_pred.argmax(dim=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")

def test_step(data_loader: torch.utils.data.DataLoader, model: torch.nn.Module, loss_fn: torch.nn.Module, device: torch.device = device, accuracy_fn = accuracy_fn):
    test_loss, test_acc = 0, 0
    model.to(device)
    with torch.inference_mode(): 
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
        
            test_pred = model(X)
        
            loss = loss_fn(test_pred, y)
            test_loss += loss.item()   # <--- FIXED
        
            test_acc += accuracy_fn(
                y_true=y,
                y_pred=test_pred.argmax(dim=1)
            )

    test_loss /= len(data_loader)
    test_acc /= len(data_loader)
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")

In [30]:
torch.manual_seed(42)

# Measure time
from timeit import default_timer as timer
from tqdm.auto import tqdm
train_time_start_on_gpu = timer()

epochs = 3
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n---------")
    train_step(data_loader=train_loader, 
        model=model_1, 
        loss_fn=loss_fn,
        optimizer=optimizer,
        accuracy_fn=accuracy_fn
    )
    test_step(data_loader=test_loader,
        model=model_1,
        loss_fn=loss_fn,
        accuracy_fn=accuracy_fn
    )

train_time_end_on_gpu = timer()
total_train_time_model_1 = print_train_time(start=train_time_start_on_gpu,
                                            end=train_time_end_on_gpu,
                                            device=device)

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 0
---------
Train loss: 0.34780 | Train accuracy: 83.28%


 33%|███▎      | 1/3 [10:28<20:57, 628.75s/it]

Test loss: 0.43020 | Test accuracy: 80.57%

Epoch: 1
---------
Train loss: 0.26644 | Train accuracy: 87.18%


 67%|██████▋   | 2/3 [21:21<10:42, 642.65s/it]

Test loss: 0.36189 | Test accuracy: 81.91%

Epoch: 2
---------
Train loss: 0.21339 | Train accuracy: 89.88%


100%|██████████| 3/3 [32:07<00:00, 642.63s/it]

Test loss: 0.37055 | Test accuracy: 84.20%

Training time on mps: 1927.786 seconds





In [31]:
torch.save(model_1.state_dict(), "brainScan_cnnLTImindtree.pth")