# Brain MRI 3D-CNN Training
Before running, upload the `brain_mri_data` folder and `mri_dataset.csv` to the `MyDrive/Thesis` folder of your Google Drive.

In [None]:
import torch
# Report whether PyTorch can see a CUDA GPU (and which one) before training.
gpu_found = torch.cuda.is_available()
print(f'CUDA available: {gpu_found}')
if gpu_found:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the Drive-based data and model paths resolve.
drive.mount('/content/drive')

In [None]:
# nibabel is used below to read the .nii / .nii.gz MRI volumes.
!pip install nibabel -q

In [None]:
import os
# All inputs and the trained checkpoint live in a single Drive folder;
# every path is derived from it.
DRIVE_FOLDER = '/content/drive/MyDrive/Thesis'
DATA_DIR = os.path.join(DRIVE_FOLDER, 'brain_mri_data')
CSV_PATH = os.path.join(DRIVE_FOLDER, 'mri_dataset.csv')
MODEL_PATH = os.path.join(DRIVE_FOLDER, 'brain_model.pth')

# Quick sanity check that the mounted Drive exposes the expected inputs.
for _label, _path in (('Data', DATA_DIR), ('CSV', CSV_PATH)):
    print(f'{_label} exists: {os.path.exists(_path)}')

In [None]:
import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from scipy import ndimage
import warnings
# Silence only deprecation-style noise. A blanket filterwarnings('ignore')
# would also hide PyTorch UserWarnings such as MSELoss's "target size ...
# different to the input size" broadcast warning, which signals a shape bug.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
# Train on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class BrainMRIDataset(Dataset):
    """Brain MRI volumes paired with [age, sex, tissue] labels.

    NIfTI files (``.nii`` / ``.nii.gz``) in ``data_dir`` are expected to be
    named ``<age>_<sex>_...``. The leading integer is the age. ``'F'`` in the
    second field means Female and anything else means Male. The tissue is
    ``'GM'`` if the name contains ``'wrp1'`` and ``'WM'`` otherwise.

    A file is kept only when ``csv_path`` has a row with the same ``age``,
    ``gender`` (``'Female'``/``'Male'``) and ``tissue``. Files whose name does
    not start with an integer are skipped. Files are visited in sorted order,
    so sample indices are deterministic.

    Items are ``(volume, labels)``:
        volume: a float tensor of shape ``(1, *target_shape)`` scaled to
            [0, 1].
        labels: ``[age / 100, sex (0 = Female, 1 = Male),
            tissue (0 = GM, 1 = WM)]``.

    Args:
        data_dir: Directory holding the NIfTI files.
        csv_path: CSV with ``age``, ``gender`` and ``tissue`` columns.
        augment: If True, about half of the items get a random rotation of
            up to +/-5 degrees in the (0, 1) plane. Half of those are also
            flipped along axis 0.
        target_shape: Shape every volume is resampled to.
    """

    def __init__(self, data_dir, csv_path, augment=False, target_shape=(128, 128, 128)):
        self.data_dir = data_dir
        self.augment = augment
        self.target_shape = target_shape
        self.file_paths = []
        self.labels = []
        df = pd.read_csv(csv_path)
        print(f'CSV has {len(df)} rows')
        print(f'Target shape for all images: {target_shape}')
        # Build the valid (age, gender, tissue) keys once instead of filtering
        # the whole DataFrame for every file.
        known = set(zip(df['age'], df['gender'], df['tissue']))
        # os.listdir order is unspecified. Sort it so two instances built on
        # the same folder index the same files, which an index-based
        # train/val split relies on.
        entries = sorted(os.listdir(data_dir))
        for f in entries:
            if not f.endswith(('.nii', '.nii.gz')):
                continue
            parts = f.split('_')
            if len(parts) < 2:
                continue
            try:
                age_from_file = int(parts[0])
            except ValueError:
                # Not named "<age>_..."; skip it rather than abort the scan.
                continue
            gender_full = 'Female' if parts[1] == 'F' else 'Male'
            tissue = 'GM' if 'wrp1' in f else 'WM'
            if (age_from_file, gender_full, tissue) not in known:
                continue
            self.file_paths.append(os.path.join(data_dir, f))
            self.labels.append([
                age_from_file / 100.0,
                0 if gender_full == 'Female' else 1,
                0 if tissue == 'GM' else 1,
            ])
        print(f'Matched {len(self.file_paths)} out of {len(entries)} files')
        print(f'Found {len(self.file_paths)} samples')

    def __len__(self):
        return len(self.file_paths)

    def _preprocess(self, data):
        """Resample to target_shape, clip to the 1st-99th percentile, rescale to [0, 1].

        A constant volume becomes all zeros.
        """
        # Drop trailing singleton axes (e.g. an (X, Y, Z, 1) volume) so there
        # is one zoom factor per remaining axis.
        while data.ndim > len(self.target_shape) and data.shape[-1] == 1:
            data = data[..., 0]
        zoom_factors = [t / o for t, o in zip(self.target_shape, data.shape)]
        data = ndimage.zoom(data, zoom_factors, order=1)
        lo, hi = np.percentile(data, [1, 99])
        data = np.clip(data, lo, hi)
        data_min = data.min()
        data_max = data.max()
        if data_max > data_min:
            return (data - data_min) / (data_max - data_min)
        return data - data_min

    def _augment(self, data):
        """With probability 0.5 rotate slightly in the (0, 1) plane, then maybe flip axis 0."""
        if np.random.rand() > 0.5:
            angle = np.random.uniform(-5, 5)
            data = ndimage.rotate(data, angle, axes=(0, 1), reshape=False, order=1)
            if np.random.rand() > 0.5:
                data = np.flip(data, axis=0).copy()
        return data

    def __getitem__(self, idx):
        img = nib.load(self.file_paths[idx])
        data = self._preprocess(img.get_fdata().astype(np.float32))
        if self.augment:
            data = self._augment(data)
        # Add the single input-channel axis expected by Conv3d.
        data = np.expand_dims(data, axis=0)
        return torch.FloatTensor(data), torch.FloatTensor(self.labels[idx])

In [None]:
class BrainCNN3D(nn.Module):
    """Multi-task 3D CNN that predicts age, sex and tissue from one volume.

    Four stages of (3x3x3 conv -> BatchNorm -> ReLU -> 2x max-pool) widen the
    channels 1 -> 32 -> 64 -> 128 -> 256. Global average pooling then feeds a
    shared 256 -> 128 linear layer with ReLU and dropout, followed by three
    single-output linear heads.

    Args:
        dropout_rate: Dropout probability after the shared linear layer.

    ``forward(x)``:
        Input: a tensor shaped (batch, 1, D, H, W). Each spatial size must be
            at least 16 so four 2x poolings leave at least one voxel.
        Returns: ``(age, sex, tissue)``, each shaped (batch, 1). ``sex`` and
            ``tissue`` are raw logits; no sigmoid is applied here.
    """

    def __init__(self, dropout_rate=0.3):
        super().__init__()
        widths = (1, 32, 64, 128, 256)
        # Register conv{i}, bn{i}, pool{i} stage by stage. The names and
        # order match the explicit layout, so state_dict keys and parameter
        # initialisation order are unchanged.
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'conv{stage}', nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'bn{stage}', nn.BatchNorm3d(c_out))
            setattr(self, f'pool{stage}', nn.MaxPool3d(2))
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(256, 128)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.age_head = nn.Linear(128, 1)
        self.sex_head = nn.Linear(128, 1)
        self.tissue_head = nn.Linear(128, 1)

    def forward(self, x):
        for stage in range(1, 5):
            conv = getattr(self, f'conv{stage}')
            bn = getattr(self, f'bn{stage}')
            pool = getattr(self, f'pool{stage}')
            x = pool(F.relu(bn(conv(x))))
        # (batch, 256, 1, 1, 1) -> (batch, 256)
        features = self.global_pool(x).flatten(1)
        features = self.dropout1(F.relu(self.fc1(features)))
        return self.age_head(features), self.sex_head(features), self.tissue_head(features)

In [None]:
EPOCHS = 200
BATCH_SIZE = 4
LR = 0.0001
TARGET_SHAPE = (128, 128, 128)
print('Creating datasets...')
# Two views of the same folder: augmented for training, plain for validation.
train_full = BrainMRIDataset(DATA_DIR, CSV_PATH, augment=True, target_shape=TARGET_SHAPE)
val_full = BrainMRIDataset(DATA_DIR, CSV_PATH, augment=False, target_shape=TARGET_SHAPE)
# The split below indexes both datasets with the same indices. If the two
# instances listed files in a different order, training and validation would
# silently overlap, so fail loudly instead.
if train_full.file_paths != val_full.file_paths:
    raise RuntimeError('Train and val datasets list different files; the index split would be inconsistent')
# With fewer than 2 samples one side of the 80/20 split is empty and the
# training loop would later divide by zero.
if len(train_full) < 2:
    raise RuntimeError(f'Need at least 2 matched samples to split, found {len(train_full)}')
# NOTE(review): this is a per-file split. If GM and WM maps of the same
# subject are separate files, one can land in train and the other in val,
# which leaks age/sex information; consider splitting by subject instead.
train_size = int(0.8 * len(train_full))
indices = list(range(len(train_full)))
np.random.seed(42)
np.random.shuffle(indices)
train_idx = indices[:train_size]
val_idx = indices[train_size:]
train_ds = torch.utils.data.Subset(train_full, train_idx)
val_ds = torch.utils.data.Subset(val_full, val_idx)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
print(f'Train: {len(train_ds)}, Val: {len(val_ds)}')

In [None]:
# One network with three heads, trained by a single optimiser.
model = BrainCNN3D(dropout_rate=0.3).to(DEVICE)

# Age is regressed with mean squared error. Sex and tissue are binary
# targets scored from raw logits, hence BCEWithLogitsLoss.
age_crit = nn.MSELoss()
sex_crit, tissue_crit = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss()

optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)
# Cosine decay from LR towards 1e-6 over EPOCHS scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-6,
)
print('Model ready')

In [None]:
from tqdm.notebook import tqdm


def _flatten_heads(outputs):
    """Turn the model's (batch, 1) head outputs into (batch,) tensors.

    A bare ``.squeeze()`` would also drop the batch axis when a batch holds a
    single sample (e.g. a short final batch). That gives a 0-d prediction,
    which ``BCEWithLogitsLoss`` rejects because input and target sizes differ.
    """
    return tuple(out.squeeze(1) for out in outputs)


def _multitask_loss(age_pred, sex_pred, tissue_pred, batch_labels):
    """Return the unweighted sum of the age MSE and the sex/tissue BCE-with-logits losses.

    ``batch_labels`` columns are [age / 100, sex, tissue], as built by
    BrainMRIDataset. Predictions must already be flattened to (batch,).
    """
    return (age_crit(age_pred, batch_labels[:, 0])
            + sex_crit(sex_pred, batch_labels[:, 1])
            + tissue_crit(tissue_pred, batch_labels[:, 2]))


# Fail fast instead of dividing by zero at the end of the first epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise RuntimeError('Both the training and validation loaders must be non-empty')

best_loss = float('inf')
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    train_seen = 0
    for batch_data, batch_labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}', leave=False):
        batch_data = batch_data.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)
        optimizer.zero_grad()
        age_pred, sex_pred, tissue_pred = _flatten_heads(model(batch_data))
        loss = _multitask_loss(age_pred, sex_pred, tissue_pred, batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Weight by batch size so the epoch loss is a per-sample mean even
        # when the final batch is shorter.
        train_loss += loss.item() * batch_labels.size(0)
        train_seen += batch_labels.size(0)
    train_loss /= train_seen

    model.eval()
    val_loss = 0.0
    age_errors = []
    sex_correct = 0
    tissue_correct = 0
    total = 0
    with torch.no_grad():
        for batch_data, batch_labels in val_loader:
            batch_data = batch_data.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            age_pred, sex_pred, tissue_pred = _flatten_heads(model(batch_data))
            loss = _multitask_loss(age_pred, sex_pred, tissue_pred, batch_labels)
            val_loss += loss.item() * batch_labels.size(0)
            # Age labels are stored as age / 100; scale the error back to years.
            age_errors.extend((torch.abs(age_pred - batch_labels[:, 0]) * 100).cpu().numpy())
            sex_correct += ((torch.sigmoid(sex_pred) > 0.5) == batch_labels[:, 1]).sum().item()
            tissue_correct += ((torch.sigmoid(tissue_pred) > 0.5) == batch_labels[:, 2]).sum().item()
            total += batch_labels.size(0)
    val_loss /= total
    age_mae = np.mean(age_errors)
    sex_acc = 100 * sex_correct / total
    tissue_acc = 100 * tissue_correct / total
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f'Epoch {epoch+1}/{EPOCHS}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, '
              f'Age MAE={age_mae:.2f}, Sex={sex_acc:.1f}%, Tissue={tissue_acc:.1f}%')
    # NOTE(review): the checkpoint is selected on the same validation set
    # whose metrics are printed, so those metrics are optimistic for the
    # saved model; a held-out test split would give an unbiased estimate.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        if (epoch + 1) % 10 == 0:
            print('  Best model saved!')
    scheduler.step()
print(f'Training complete! Best loss: {best_loss:.4f}')

In [None]:
from google.colab import files
# Trigger a browser download of the checkpoint at MODEL_PATH. Training saves
# the state_dict with the lowest validation loss there.
files.download(MODEL_PATH)