In [1]:
import os
import glob
import numpy as np
import pandas as pd
import shutil
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

from random import shuffle

import torch
import torch.nn as nn
import torchvision.transforms as T
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from PIL import Image

from skimage.transform import resize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    def __init__(self, file_paths, labels):
        self.file_paths = file_paths
        self.labels = labels
        self.target_size = 224  # Must match the model's expected input size

    def normalize_slice(self, slice_data):
        # Z-score normalization often works better for medical images
        mean = np.mean(slice_data)
        std = np.std(slice_data)
        return (slice_data - mean) / (std + 1e-8)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load NIfTI file
        img_path = self.file_paths[idx]
        label = self.labels[idx]

        try:
            # Load and preprocess NIfTI image
            nifti_img = nib.load(img_path)
            img_data = nifti_img.get_fdata()

            # Process to get 3-channel 2D representation
            if len(img_data.shape) == 3:
                x_mid = img_data.shape[0] // 2
                y_mid = img_data.shape[1] // 2
                z_mid = img_data.shape[2] // 2

                # Extract slices
                slice1 = img_data[x_mid, :, :]
                slice2 = img_data[:, y_mid, :]
                slice3 = img_data[:, :, z_mid]

                # Apply z-score normalization instead of min-max
                slice1 = self.normalize_slice(slice1)
                slice2 = self.normalize_slice(slice2)
                slice3 = self.normalize_slice(slice3)

                # Resize to exactly 224x224 for ViT models
                slice1_resized = resize(slice1, (self.target_size, self.target_size), anti_aliasing=True)
                slice2_resized = resize(slice2, (self.target_size, self.target_size), anti_aliasing=True)
                slice3_resized = resize(slice3, (self.target_size, self.target_size), anti_aliasing=True)

                # Stack along new channel dimension
                img_array = np.stack([slice1_resized, slice2_resized, slice3_resized], axis=0)

                # Convert to tensor
                img_tensor = torch.from_numpy(img_array).float()

                return img_tensor, torch.tensor(label, dtype=torch.long)
            else:
                raise ValueError(f"Expected 3D NIfTI data, got shape {img_data.shape}")
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            # Return a placeholder tensor with correct size
            img_tensor = torch.zeros((3, self.target_size, self.target_size), dtype=torch.float32)
            return img_tensor, torch.tensor(label, dtype=torch.long)

In [3]:
# Load DINOv2 model
def load_dinov2_model(variant="vitb14"):
    """Load DINOv2 model for feature extraction"""
    model = torch.hub.load('facebookresearch/dinov2', f'dinov2_{variant}')
    return model

In [4]:
class ADHDClassifier(nn.Module):
    def __init__(self, backbone="resnet18", freeze_backbone=True):
        super(ADHDClassifier, self).__init__()
        self.backbone_name = backbone

        # Map custom backbone names to valid timm model names
        if backbone == "vitb14":
            # Use base model with patch size 16 instead, which has pretrained weights
            timm_model_name = "vit_base_patch16_224"
            use_pretrained = True
        else:
            # For other backbone names, try to use them directly
            timm_model_name = backbone
            use_pretrained = True

        try:
            # Create the backbone using the mapped name
            self.backbone = timm.create_model(timm_model_name, pretrained=use_pretrained)

            # Replace the head of the model
            embed_dim = self.backbone.head.in_features
            self.backbone.head = nn.Identity()  # Remove classification head

            # Create a new classifier head
            self.classifier = nn.Sequential(
                nn.Linear(embed_dim, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, 2)  # Binary classification: ADHD vs Control
            )

            # Freeze backbone if requested and using pretrained
            if freeze_backbone and use_pretrained:
                for param in self.backbone.parameters():
                    param.requires_grad = False

        except Exception as e:
            print(f"Error creating model with {timm_model_name}: {e}")
            print("Trying with vit_base_patch16_224 without pretrained weights...")

            # Fallback to vit_base_patch16_224 without pretrained weights
            self.backbone = timm.create_model('vit_base_patch16_224', pretrained=False)
            embed_dim = self.backbone.head.in_features
            self.backbone.head = nn.Identity()

            self.classifier = nn.Sequential(
                nn.Linear(embed_dim, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, 2)
            )

            # Don't freeze if using random initialization
            if freeze_backbone:
                print("Note: Not freezing backbone since we're using random initialization")

    def forward(self, x):
        # Add this resize operation before passing to backbone
        if x.shape[-1] != 224 or x.shape[-2] != 224:
            #print(f"Resizing input from {x.shape[-2]}x{x.shape[-1]} to 224x224")
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        # Extract features
        features = self.backbone(x)

        # Classification
        logits = self.classifier(features)

        return logits

In [5]:
def train_adhd_classifier(model, train_loader, val_loader, epochs=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Add the scheduler here
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    best_val_acc = 0.0
    best_model = None

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Calculate average training loss
        train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}')

        # Validation phase
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

        # Update the scheduler based on validation accuracy
        scheduler.step(val_accuracy)

        # Save the best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            best_model = model.state_dict().copy()

    # Load the best model
    if best_model:
        model.load_state_dict(best_model)

    return model

In [6]:
def custom_collate(batch):
    images = []
    labels = []

    for image, label in batch:
        # Ensure image is exactly 256x256
        if image.shape[1] != 256 or image.shape[2] != 256:
            # Resize to 256x256 using F.interpolate
            image = F.interpolate(image.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)

        images.append(image)
        labels.append(label)

    # Stack all images and labels
    images = torch.stack(images)
    labels = torch.stack(labels)

    return images, labels

In [7]:
# Define Transform
#transform = T.Compose([
    # Convert to PIL image first (required for many transforms)
#    T.ToPILImage(),
    # Resize to exactly 256x256
#    T.Resize((256, 256)),
    # Convert back to tensor
#    T.ToTensor(),
#    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#])
# Define transform without resize (we'll handle it in the collate function)
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [8]:
adhd_image_path = "image_data/adhd"
control_image_path = "image_data/control"

adhd_files = [os.path.join(adhd_image_path, file_path) for file_path in os.listdir(adhd_image_path) if file_path.endswith(".nii") or file_path.endswith(".nii.gz")]
control_files = [os.path.join(adhd_image_path, file_path) for file_path in os.listdir(adhd_image_path) if file_path.endswith(".nii") or file_path.endswith(".nii.gz")]

all_files = adhd_files + control_files
labels = [1] * len(adhd_files) + [0] * len(control_files)

# Split dataset
train_files, test_files, train_labels, test_labels = train_test_split(
    all_files, labels, test_size=0.2, random_state=42, stratify=labels
)

In [9]:
# Create datasets and data loaders
train_dataset = MRIDataset(train_files, train_labels)
test_dataset = MRIDataset(test_files, test_labels)

In [10]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [11]:
# Verify the size of a batch
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")  # Should be [batch_size, 3, 224, 224]
    break

Batch shape: torch.Size([16, 3, 224, 224])


In [12]:
# Initialize and train model
model = ADHDClassifier(backbone="vitb14", freeze_backbone=True)
trained_model = train_adhd_classifier(model, train_loader, test_loader, epochs=20)

Epoch 1/20, Train Loss: 1.2772
Validation Loss: 1.8824, Accuracy: 48.65%
Epoch 2/20, Train Loss: 1.2883
Validation Loss: 0.7749, Accuracy: 51.35%
Epoch 3/20, Train Loss: 0.7767
Validation Loss: 0.6952, Accuracy: 51.35%
Epoch 4/20, Train Loss: 0.6871
Validation Loss: 0.7303, Accuracy: 45.95%
Epoch 5/20, Train Loss: 0.6924
Validation Loss: 0.7005, Accuracy: 45.95%
Epoch 6/20, Train Loss: 0.6951
Validation Loss: 0.7156, Accuracy: 29.73%
Epoch 7/20, Train Loss: 0.6978
Validation Loss: 0.7126, Accuracy: 29.73%
Epoch 8/20, Train Loss: 0.6890
Validation Loss: 0.7179, Accuracy: 21.62%
Epoch 9/20, Train Loss: 0.6935
Validation Loss: 0.7205, Accuracy: 29.73%
Epoch 10/20, Train Loss: 0.6891
Validation Loss: 0.7249, Accuracy: 21.62%
Epoch 11/20, Train Loss: 0.6888
Validation Loss: 0.7241, Accuracy: 18.92%
Epoch 12/20, Train Loss: 0.6858
Validation Loss: 0.7344, Accuracy: 35.14%
Epoch 13/20, Train Loss: 0.6872
Validation Loss: 0.7239, Accuracy: 32.43%
Epoch 14/20, Train Loss: 0.6853
Validation Loss

In [13]:
# Evaluate model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = trained_model(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    # Print classification matrix
    print(classification_report(all_labels, all_preds, target_names=["Control", "ADHD"]))
    print("Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))

              precision    recall  f1-score   support

     Control       0.17      0.11      0.13        19
        ADHD       0.32      0.44      0.37        18

    accuracy                           0.27        37
   macro avg       0.24      0.27      0.25        37
weighted avg       0.24      0.27      0.25        37

Confusion Matrix:
[[ 2 17]
 [10  8]]
