# Project 1: Brain Tumor Segmentation and Grade Classification

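This notebook implements:
1. **2D U-Net segmentation** of brain tumors from MRI, processed slice by slice and trained separately on three input configurations: T1 only, T2 only, and T1+T2 stacked as channels
2. **Tumor grade classification** (HGG, high-grade glioma, vs. LGG, low-grade glioma) using a 3D CNN that takes the full T2 volume as input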


## 1. Imports and Setup


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import nibabel as nib
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix, 
                           roc_auc_score, accuracy_score, f1_score,
                           precision_score, recall_score)
from pathlib import Path
from collections import Counter
import warnings
# Suppress every Python warning so library notices do not clutter cell output.
warnings.filterwarnings('ignore')

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


## 2. Data Configuration


In [None]:
# Dataset locations (absolute, machine-specific paths)
GRADE_MAPPING_FILE = Path(r"E:\Brain Tumor Segmentation\archive\BraTS2020_training_data\content\data\name_mapping.csv")
DATA_DIR = Path(r"E:\Brain Tumor Segmentation\archive\3D Slices Sorted")

# Patient-level train/test split
RANDOM_STATE = 42
TEST_SIZE = 0.2  # fraction of patients held out for testing (80/20 split)

# Optimisation settings
LEARNING_RATE = 1e-3  # passed to both the segmentation and the classification training calls
SEGMENTATION_EPOCHS = 30
CLASSIFICATION_EPOCHS = 30

# Mini-batch sizes
BATCH_SIZE_SEG = 16  # 2D slices per batch
BATCH_SIZE_CLS = 4   # whole 3D volumes per batch


## 3. Model Architectures

### 3.1 Segmentation Model: 2D U-Net


In [None]:
class UNet2D(nn.Module):
    """Four-level 2D U-Net that maps an image slice to per-pixel class logits.

    The encoder halves the spatial size four times with 2x2 max pooling and
    the decoder doubles it back with stride-2 transposed convolutions, so the
    input height and width should be multiples of 16 for every upsampled map
    to line up with its skip connection.

    Args:
        in_channels: Number of input channels (one per stacked modality).
        num_classes: Number of output channels of the final 1x1 convolution.
        base_features: Channel width of the first encoder stage; every
            deeper stage doubles it.
    """

    def __init__(self, in_channels=1, num_classes=4, base_features=32):
        super().__init__()

        # Channel widths for the four encoder stages and the bottleneck.
        w1, w2, w3, w4, w5 = (base_features * (2 ** level) for level in range(5))

        # Contracting path
        self.enc1 = self._conv_block(in_channels, w1)
        self.enc2 = self._conv_block(w1, w2)
        self.enc3 = self._conv_block(w2, w3)
        self.enc4 = self._conv_block(w3, w4)

        # Deepest stage
        self.bottleneck = self._conv_block(w4, w5)

        # Expanding path: each dec block sees the upsampled map concatenated
        # with the same-width skip tensor, hence twice the channel count.
        self.up4 = nn.ConvTranspose2d(w5, w4, 2, 2)
        self.dec4 = self._conv_block(w5, w4)

        self.up3 = nn.ConvTranspose2d(w4, w3, 2, 2)
        self.dec3 = self._conv_block(w4, w3)

        self.up2 = nn.ConvTranspose2d(w3, w2, 2, 2)
        self.dec2 = self._conv_block(w3, w2)

        self.up1 = nn.ConvTranspose2d(w2, w1, 2, 2)
        self.dec1 = self._conv_block(w2, w1)

        # 1x1 convolution producing the class logits
        self.final = nn.Conv2d(w1, num_classes, kernel_size=1)

        # Shared 2x2 downsampling used between encoder stages
        self.pool = nn.MaxPool2d(2)

        # Standalone dropout module; forward() does not call it (each conv
        # block already ends with its own Dropout2d).
        self.dropout = nn.Dropout2d(0.1)

    def _conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers followed by Dropout2d(0.1)."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with ``num_classes`` channels at the input's spatial size."""
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        features = self.bottleneck(self.pool(skip4))

        # Decoder: upsample, concatenate the matching skip, refine.
        decoder_stages = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, refine, skip in decoder_stages:
            features = refine(torch.cat([upsample(features), skip], dim=1))

        return self.final(features)


### 3.2 Classification Model: Grade Classifier


In [None]:
class TumorGradeClassifier(nn.Module):
    """3D CNN that maps a single-channel volume to tumor-grade logits.

    Four (Conv3d -> BatchNorm3d -> ReLU) stages widen the channels
    1 -> 32 -> 64 -> 128 -> 256. The first three stages are each followed by
    a 2x max pool and the last by an adaptive average pool to a fixed 4x4x4
    grid, so the fully connected head receives 256 * 4 * 4 * 4 features for
    any input size. All weights are initialised from scratch; nothing is
    loaded from the segmentation network.

    Args:
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used after both hidden linear layers.
    """

    def __init__(self, num_classes=2, dropout_rate=0.5):
        super().__init__()

        channel_plan = [1, 32, 64, 128, 256]
        final_stage = len(channel_plan) - 2

        stages = []
        for stage, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm3d(c_out))
            stages.append(nn.ReLU(inplace=True))
            if stage < final_stage:
                # Halve every spatial axis, e.g. 240x240x155 -> 120x120x77
                # -> 60x60x38 -> 30x30x19 over the three pooled stages.
                stages.append(nn.MaxPool3d(2))
            else:
                # Fixed 4x4x4 output grid regardless of the incoming size.
                stages.append(nn.AdaptiveAvgPool3d((4, 4, 4)))
        self.feature_extractor = nn.Sequential(*stages)

        # Fully connected head: 16384 -> 512 -> 128 -> num_classes
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of single-channel 3D volumes."""
        return self.classifier(self.feature_extractor(x))


## 4. Loss Functions


In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for multi-class 2D segmentation.

    ``forward`` takes logits of shape (N, C, H, W) and an integer label map
    of shape (N, H, W). It computes a smoothed Dice coefficient for every
    (sample, class) pair and returns one minus their mean.

    Args:
        smooth: Constant added to numerator and denominator so that a class
            absent from both prediction and target scores 1 instead of 0/0.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        n_classes = pred.shape[1]
        probs = pred.softmax(dim=1)
        # one_hot appends the class axis last; move it next to the batch axis.
        onehot = F.one_hot(target, num_classes=n_classes).permute(0, 3, 1, 2).float()

        spatial_axes = (2, 3)
        overlap = (probs * onehot).sum(dim=spatial_axes)
        total_mass = probs.sum(dim=spatial_axes) + onehot.sum(dim=spatial_axes)

        dice_per_class = (2.0 * overlap + self.smooth) / (total_mass + self.smooth)
        return 1 - dice_per_class.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of the soft Dice loss and pixel-wise cross-entropy.

    Args:
        dice_weight: Multiplier applied to the Dice term.
        ce_weight: Multiplier applied to the cross-entropy term.
    """

    def __init__(self, dice_weight=0.5, ce_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.dice_loss = DiceLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, pred, target):
        """Return ``dice_weight * dice + ce_weight * ce`` for logits ``pred`` and labels ``target``."""
        dice_term = self.dice_weight * self.dice_loss(pred, target)
        ce_term = self.ce_weight * self.ce_loss(pred, target)
        return dice_term + ce_term

class WeightedFocalLoss(nn.Module):
    """Focal loss whose per-sample value is scaled by a per-class weight.

    Args:
        class_weights: Sequence of weights indexed by class id.
        alpha: Global multiplier of the focal term.
        gamma: Focusing exponent applied to ``1 - p_true``.
    """

    def __init__(self, class_weights, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = torch.FloatTensor(class_weights)

    def forward(self, pred, target):
        """Return the mean weighted focal loss for logits ``pred`` and class indices ``target``."""
        per_sample_ce = F.cross_entropy(pred, target, reduction='none')
        # exp(-CE) recovers the probability assigned to the true class.
        true_class_prob = torch.exp(-per_sample_ce)
        focal = self.alpha * (1 - true_class_prob) ** self.gamma * per_sample_ce

        # The weight table is a plain tensor attribute, so it is moved to the
        # logits' device lazily the first time the devices differ.
        if pred.device != self.class_weights.device:
            self.class_weights = self.class_weights.to(pred.device)
        sample_weights = self.class_weights[target]

        return (sample_weights * focal).mean()


## 5. Dataset Classes


In [None]:
class SliceBasedDataset(Dataset):
    """Slice-level dataset for 2D brain tumor segmentation.

    Each sample is one slice along the third axis of one patient's volume:
    an image of shape (n_modalities, H, W) and its label map of shape (H, W).
    Files are looked up as
    ``<data_dir>/BraTS20_Training_<id:03d>_<modality>.nii.gz`` and
    ``<data_dir>/masks/BraTS20_Training_<id:03d>_mask.nii.gz``.

    Args:
        data_dir: Directory holding the modality volumes and ``masks/``.
        patient_ids: Iterable of integer patient ids.
        modalities: Modality names stacked as channels, in order. Defaults
            to ``['T1']``. The first one is used to count slices.
        transform: Optional callable ``(image, mask) -> (image, mask)``
            applied to the NumPy arrays before conversion to tensors.

    Patients whose first-modality file is missing contribute no samples.
    A missing file for a later modality, or a missing mask, is replaced by
    an all-zero 240x240 slice.
    """

    # In-plane size of the zero slice substituted for a missing file.
    _FALLBACK_SLICE_SHAPE = (240, 240)

    def __init__(self, data_dir, patient_ids, modalities=None, transform=None):
        self.data_dir = Path(data_dir)
        self.patient_ids = patient_ids
        # A None sentinel avoids sharing one mutable default list between instances.
        self.modalities = ['T1'] if modalities is None else modalities
        self.transform = transform
        self.samples = self._create_samples()

    def _volume_path(self, patient_id, modality):
        """Return the expected path of one patient's volume for ``modality``."""
        return self.data_dir / f"BraTS20_Training_{patient_id:03d}_{modality}.nii.gz"

    @staticmethod
    def _read_slice(path, slice_idx):
        """Return slice ``slice_idx`` (third axis) of the volume at ``path`` as float64."""
        # Indexing the image's array proxy avoids materialising the whole
        # volume as a float64 array for every single slice request.
        return np.asarray(nib.load(path).dataobj[:, :, slice_idx], dtype=np.float64)

    @staticmethod
    def _remap_labels(mask):
        """Return a copy of ``mask`` with label 4 replaced by 3."""
        # NOTE(review): BraTS ground truth conventionally uses labels
        # {0, 1, 2, 4}, while the 4-class model and losses in this file need
        # labels in [0, 3]. The remap is a no-op for masks that were already
        # re-indexed -- confirm against the actual mask files.
        remapped = np.array(mask, copy=True)
        remapped[remapped == 4] = 3
        return remapped

    def _create_samples(self):
        """Create the list of ``(patient_id, slice_index)`` samples."""
        samples = []
        for patient_id in self.patient_ids:
            img_path = self._volume_path(patient_id, self.modalities[0])
            if not img_path.exists():
                continue
            # The image shape comes from the header; no voxel data is decoded.
            num_slices = nib.load(img_path).shape[2]
            samples.extend((patient_id, i) for i in range(num_slices))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors: float32 (C, H, W) and int64 (H, W)."""
        patient_id, slice_idx = self.samples[idx]

        images = []
        for modality in self.modalities:
            img_path = self._volume_path(patient_id, modality)
            if img_path.exists():
                img_slice = self._read_slice(img_path, slice_idx)
                # Per-slice min-max scaling; the epsilon guards constant slices.
                lo, hi = img_slice.min(), img_slice.max()
                images.append((img_slice - lo) / (hi - lo + 1e-8))
            else:
                images.append(np.zeros(self._FALLBACK_SLICE_SHAPE))

        image = np.stack(images, axis=0)  # (n_modalities, H, W)
        mask = self.load_segmentation_mask(patient_id, slice_idx)

        if self.transform:
            image, mask = self.transform(image, mask)

        return (
            torch.as_tensor(image, dtype=torch.float32),
            torch.as_tensor(mask, dtype=torch.long),
        )

    def load_segmentation_mask(self, patient_id, slice_idx):
        """Return the uint8 label map of one slice, or zeros if the mask file is missing."""
        mask_path = self.data_dir / "masks" / f"BraTS20_Training_{patient_id:03d}_mask.nii.gz"
        if not mask_path.exists():
            return np.zeros(self._FALLBACK_SLICE_SHAPE, dtype=np.uint8)
        mask_slice = self._read_slice(mask_path, slice_idx).astype(np.uint8)
        return self._remap_labels(mask_slice)

class BrainTumorClassificationDataset(Dataset):
    """Volume-level dataset for HGG-vs-LGG grade classification.

    Each item is ``(volume, label)``: the patient's min-max normalised T2
    volume as a float tensor with a leading channel axis, and a one-element
    long tensor holding 1 for 'HGG' or 0 for any other grade value.

    Args:
        data_dir: Directory containing ``BraTS20_Training_<id:03d>_T2.nii.gz`` files.
        grade_mapping: Mapping from ``BraTS20_Training_<id:03d>`` to a grade string.
        patient_ids: Indexable collection of integer patient ids.
        transform: Optional callable applied to the volume tensor.
    """

    def __init__(self, data_dir, grade_mapping, patient_ids, transform=None):
        self.data_dir = Path(data_dir)
        self.grade_mapping = grade_mapping
        self.patient_ids = patient_ids
        self.transform = transform

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]

        # Raises KeyError when the patient has no entry in the grade mapping.
        patient_name = f"BraTS20_Training_{patient_id:03d}"
        is_high_grade = self.grade_mapping[patient_name] == 'HGG'
        target = torch.LongTensor([1 if is_high_grade else 0])

        img_path = self.data_dir / f"BraTS20_Training_{patient_id:03d}_T2.nii.gz"
        if not img_path.exists():
            # Missing scan: substitute an all-zero volume of fixed size.
            return torch.zeros((1, 240, 240, 155)), target

        volume = nib.load(img_path).get_fdata()
        # Min-max scale the whole volume; the epsilon guards constant volumes.
        lowest, highest = volume.min(), volume.max()
        volume = (volume - lowest) / (highest - lowest + 1e-8)

        # Add the channel axis expected by Conv3d: (1, H, W, D).
        tensor = torch.FloatTensor(volume).unsqueeze(0)
        if self.transform:
            tensor = self.transform(tensor)

        return tensor, target


## 6. Evaluation Metrics


In [None]:
def calculate_dice_scores(pred, target, num_classes=4):
    """Calculate the hard Dice score of every class over a whole batch.

    Args:
        pred: Class scores with the class axis at dim 1 (logits or
            probabilities; only the per-pixel argmax is used).
        target: Integer label tensor matching ``pred`` without its class axis.
        num_classes: Number of class ids to evaluate, starting at 0.

    Returns:
        A list of ``num_classes`` Python floats. A class that appears in
        neither the prediction nor the target scores 1.0.
    """
    # Softmax is monotonic, so the argmax of the raw scores already gives the
    # predicted class; no normalisation is needed.
    pred_classes = torch.argmax(pred, dim=1)

    dice_scores = []
    for class_idx in range(num_classes):
        pred_mask = pred_classes == class_idx
        target_mask = target == class_idx

        # Integer pixel counts stay exact for any batch size.
        denominator = pred_mask.sum() + target_mask.sum()
        if denominator == 0:
            # Class absent from both prediction and target: perfect agreement.
            dice_scores.append(1.0)
            continue

        intersection = (pred_mask & target_mask).sum()
        dice_scores.append((2.0 * intersection / denominator).item())

    return dice_scores


# PART 1: SEGMENTATION

## 7. Data Preparation for Segmentation


In [None]:
def prepare_segmentation_data(data_dir, modalities=None, test_size=0.2, random_state=42):
    """Prepare train/test slice datasets for 2D segmentation training.

    Patient ids are parsed from every ``*_<modality>.nii.gz`` file name in
    ``data_dir`` (third underscore-separated field) and pooled over all
    requested modalities. The split is made at patient level, so all slices
    of one patient land in the same subset.

    Args:
        data_dir: Directory containing the modality volumes.
        modalities: Modality names to use as input channels. Defaults to ``['T1']``.
        test_size: Fraction of patients assigned to the test set.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_dataset, test_dataset, train_ids, test_ids)``.

    Raises:
        ValueError: If no matching file is found, or a matching file name
            has a non-integer id field.
    """
    data_dir = Path(data_dir)
    # A None sentinel avoids a shared mutable default list.
    if modalities is None:
        modalities = ['T1']

    # A set de-duplicates ids seen under several modalities in O(1) each.
    found_ids = set()
    for modality in modalities:
        for file in data_dir.glob(f"*_{modality}.nii.gz"):
            found_ids.add(int(file.name.split('_')[2]))
    patient_ids = sorted(found_ids)

    if not patient_ids:
        # Fail here with a message that names the actual cause; the splitter
        # would otherwise raise a ValueError about an empty sample set.
        raise ValueError(
            f"No volumes found in {data_dir} for modalities {list(modalities)}"
        )

    # Split data (80% train, 20% test with the default test_size)
    train_ids, test_ids = train_test_split(patient_ids, test_size=test_size, random_state=random_state)

    # Create datasets (slice-based for 2D U-Net)
    train_dataset = SliceBasedDataset(data_dir, train_ids, modalities)
    test_dataset = SliceBasedDataset(data_dir, test_ids, modalities)

    print(f"Found {len(patient_ids)} patients")
    print(f"Training samples: {len(train_dataset)} slices")
    print(f"Test samples: {len(test_dataset)} slices")

    return train_dataset, test_dataset, train_ids, test_ids


## 8. Segmentation Training


In [None]:
def train_segmentation_model(modalities, train_dataset, test_dataset, epochs=30, batch_size=16, lr=1e-3):
    """Train a 2D U-Net on ``train_dataset`` and validate it on ``test_dataset`` every epoch.

    The network gets one input channel per entry of ``modalities`` and four
    output classes. It is optimised with AdamW on the combined Dice +
    cross-entropy loss; the learning rate is halved after five epochs without
    improvement of the validation loss.

    Args:
        modalities: Modality names; sets the input channel count and labels the logs.
        train_dataset: Dataset yielding ``(image, mask)`` pairs for training.
        test_dataset: Dataset yielding ``(image, mask)`` pairs for validation.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Mini-batch size of both data loaders.
        lr: Initial AdamW learning rate.

    Returns:
        A dict with:
            ``'model'``: the network with its final-epoch weights,
            ``'best_model_state'``: an independent copy of the weights from
                the epoch with the highest mean validation Dice (``None`` if
                the Dice never exceeded 0),
            ``'train_losses'``, ``'val_losses'``, ``'val_dice_scores'``:
                per-epoch histories,
            ``'best_dice'``: the highest mean validation Dice.
    """
    import copy  # local import: only needed to snapshot the best weights

    print(f"\n{'='*60}")
    print(f"Training 2D U-Net segmentation model for modalities: {modalities}")
    print(f"{'='*60}")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Initialize model (2D U-Net)
    model = UNet2D(in_channels=len(modalities), num_classes=4).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    criterion = CombinedLoss()

    # Training tracking
    train_losses = []
    val_losses = []
    val_dice_scores = []
    best_dice = 0
    best_model_state = None

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        dice_scores_epoch = []

        with torch.no_grad():
            for images, masks in test_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Per-class Dice of this batch
                dice_scores = calculate_dice_scores(outputs, masks)
                dice_scores_epoch.append(dice_scores)

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(test_loader)
        # Mean over classes within each batch, then mean over batches.
        avg_dice = np.mean([np.mean(dice) for dice in dice_scores_epoch])

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        val_dice_scores.append(avg_dice)

        scheduler.step(avg_val_loss)

        print(f'Epoch {epoch+1}/{epochs}: Train Loss: {avg_train_loss:.4f}, '
              f'Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_dice:.4f}')

        # Save best model
        if avg_dice > best_dice:
            best_dice = avg_dice
            # state_dict() returns references to the live parameter tensors,
            # so a shallow dict copy would keep changing as training goes on.
            # A deep copy freezes this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())

    print(f"\nBest Dice Score for {modalities}: {best_dice:.4f}")

    return {
        'model': model,
        'best_model_state': best_model_state,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_dice_scores': val_dice_scores,
        'best_dice': best_dice
    }


In [None]:
# Input configurations to compare: each modality alone, then both stacked as channels
modality_combinations = [['T1'], ['T2'], ['T1', 'T2']]

# Maps "T1" / "T2" / "T1+T2" to the result dict returned by training
segmentation_results = {}

# One full prepare-and-train cycle per configuration
for modalities in modality_combinations:
    # Patient-level split and slice datasets for this configuration
    train_dataset, test_dataset, train_ids, test_ids = prepare_segmentation_data(
        data_dir=DATA_DIR,
        modalities=modalities,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )

    # Fresh network trained from scratch on this configuration
    result = train_segmentation_model(
        modalities=modalities,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        epochs=SEGMENTATION_EPOCHS,
        batch_size=BATCH_SIZE_SEG,
        lr=LEARNING_RATE,
    )

    model_name = '+'.join(modalities)
    segmentation_results[model_name] = result


## 10. Segmentation Results Summary


In [None]:
# Report the best validation Dice reached by every trained modality configuration.
print("\n" + "="*60)
print("SEGMENTATION RESULTS SUMMARY")
print("="*60)

for model_name, results in segmentation_results.items():
    print(f"{model_name}: Best Dice Score = {results['best_dice']:.4f}")

# Find best model
# best_seg_model is a (model_name, result_dict) pair; max() raises ValueError
# if segmentation_results is empty.
best_seg_model = max(segmentation_results.items(), key=lambda x: x[1]['best_dice'])
print(f"\nBest performing model: {best_seg_model[0]} with Dice Score: {best_seg_model[1]['best_dice']:.4f}")


# PART 2: CLASSIFICATION

## 11. Data Preparation for Classification


In [None]:
# Load grade mapping
def load_grade_mapping(grade_mapping_file):
    """Load the subject-ID -> grade lookup from a CSV file.

    Args:
        grade_mapping_file: Path or file-like object accepted by
            ``pandas.read_csv``. The CSV must contain the columns
            ``BraTS_2020_subject_ID`` and ``Grade``.

    Returns:
        A dict mapping each subject ID to its grade value. When an ID occurs
        more than once, the last row wins.

    Raises:
        KeyError: If either required column is missing.
    """
    df = pd.read_csv(grade_mapping_file)
    # Pair the two columns directly instead of materialising a Series per row.
    return dict(zip(df['BraTS_2020_subject_ID'], df['Grade']))

# Build the subject-ID -> grade lookup from the CSV at GRADE_MAPPING_FILE.
grade_mapping = load_grade_mapping(GRADE_MAPPING_FILE)
print(f"Loaded grade information for {len(grade_mapping)} patients")

# Prepare classification data
def prepare_classification_data(data_dir, grade_mapping, test_size=0.2, random_state=42):
    """Select patients with a known grade and split them into stratified train/test sets.

    A patient is included when a ``*_T2.nii.gz`` file exists in ``data_dir``
    whose third underscore-separated name field is an integer id and whose
    ``BraTS20_Training_<id:03d>`` key is present in ``grade_mapping``.

    Args:
        data_dir: Directory containing the T2 volumes.
        grade_mapping: Mapping from subject ID to grade string.
        test_size: Fraction of patients assigned to the test set.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_ids, test_ids, train_labels, test_labels)`` where a label is
        1 for 'HGG' and 0 for any other grade value.

    Raises:
        ValueError: If no patient with grade information is found.
    """
    data_dir = Path(data_dir)

    # Get all patient IDs with grade information
    patient_ids = []
    for file in data_dir.glob("*_T2.nii.gz"):
        try:
            patient_id = int(file.name.split('_')[2])
        except (IndexError, ValueError):
            # The name does not follow BraTS20_Training_<id>_T2.nii.gz; skip
            # it. Only these two parse failures are ignored, so unrelated
            # errors are no longer hidden.
            continue

        if f"BraTS20_Training_{patient_id:03d}" in grade_mapping:
            patient_ids.append(patient_id)

    patient_ids.sort()
    print(f"Found {len(patient_ids)} patients with grade information")

    if not patient_ids:
        # Fail here with a message that names the actual cause; the splitter
        # would otherwise raise a ValueError about an empty sample set.
        raise ValueError(f"No T2 volumes with grade information found in {data_dir}")

    # Get labels for stratification (HGG=1, everything else=0)
    labels = [
        1 if grade_mapping[f"BraTS20_Training_{patient_id:03d}"] == 'HGG' else 0
        for patient_id in patient_ids
    ]

    # Check class distribution
    label_counts = Counter(labels)
    print(f"Class distribution: HGG={label_counts[1]}, LGG={label_counts[0]}")

    # Stratified split (80% train, 20% test with the default test_size)
    train_ids, test_ids, train_labels, test_labels = train_test_split(
        patient_ids, labels, test_size=test_size, random_state=random_state,
        stratify=labels
    )

    return train_ids, test_ids, train_labels, test_labels

# Patient-level stratified split for grade classification. This rebinds the
# train_ids / test_ids names that the segmentation experiment loop also assigned.
train_ids, test_ids, train_labels, test_labels = prepare_classification_data(
    DATA_DIR, grade_mapping, test_size=TEST_SIZE, random_state=RANDOM_STATE
)
print(f"Training samples: {len(train_ids)}")
print(f"Test samples: {len(test_ids)}")


## 12. Classification Training


In [None]:
def calculate_class_weights(labels):
    """Return inverse-frequency class weights, ordered by ascending class id.

    Each weight is ``n_samples / (n_classes * class_count)``, where
    ``n_classes`` counts only the classes that occur in ``labels``. An empty
    ``labels`` sequence yields an empty list.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return [n_samples / (n_classes * counts[class_id]) for class_id in sorted(counts)]

def create_weighted_sampler(labels):
    """Build a sampler that draws each sample with inverse-class-frequency weight.

    Every sample gets the weight ``n_samples / (n_classes * count_of_its_class)``.
    The returned ``WeightedRandomSampler`` draws ``len(labels)`` indices per epoch.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)

    per_sample_weights = [n_samples / (n_classes * counts[label]) for label in labels]
    return WeightedRandomSampler(per_sample_weights, len(per_sample_weights))

def train_classification_model(train_ids, test_ids, train_labels, test_labels, 
                               epochs=30, batch_size=4, lr=1e-4):
    """Train the 3D grade classifier and validate it on the test patients every epoch.

    Datasets are built from the module-level ``DATA_DIR`` and
    ``grade_mapping``. Training batches are drawn with an inverse-frequency
    weighted sampler, and the loss is a class-weighted focal loss. The
    optimiser is AdamW with cosine annealing warm restarts.

    Args:
        train_ids: Patient ids used for training.
        test_ids: Patient ids used for per-epoch validation.
        train_labels: Integer labels aligned with ``train_ids``; they drive
            the sampler and the class weights.
        test_labels: Integer labels aligned with ``test_ids`` (accepted for
            symmetry; the function does not read them).
        epochs: Number of training epochs.
        batch_size: Mini-batch size of both data loaders.
        lr: Initial AdamW learning rate.

    Returns:
        A dict with:
            ``'model'``: the classifier with its final-epoch weights,
            ``'best_model_state'``: an independent copy of the weights from
                the epoch with the highest weighted F1 (``None`` if the F1
                never exceeded 0),
            ``'train_losses'``, ``'val_losses'``, ``'train_accuracies'``,
                ``'val_accuracies'``: per-epoch histories,
            ``'best_f1'``: the highest weighted validation F1.
    """
    import copy  # local import: only needed to snapshot the best weights

    print(f"\n{'='*80}")
    print(f"Training Brain Tumor Grade Classification Model")
    print(f"{'='*80}")

    # Create datasets
    train_dataset = BrainTumorClassificationDataset(DATA_DIR, grade_mapping, train_ids)
    test_dataset = BrainTumorClassificationDataset(DATA_DIR, grade_mapping, test_ids)

    # Create weighted sampler for training
    train_sampler = create_weighted_sampler(train_labels)

    # Create data loaders
    use_pinned_memory = device.type == 'cuda'
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=2,
        pin_memory=use_pinned_memory
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned_memory
    )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Initialize model
    model = TumorGradeClassifier(num_classes=2, dropout_rate=0.5).to(device)

    # Calculate class weights
    class_weights = calculate_class_weights(train_labels)
    print(f"Class weights: {class_weights}")

    # Optimizer and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    criterion = WeightedFocalLoss(class_weights=class_weights, alpha=1, gamma=2)

    # Training tracking
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    best_f1 = 0
    best_model_state = None

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            # Flatten labels to 1-D. A bare squeeze() would turn a final batch
            # of one sample into a 0-dim tensor and break the loss and
            # labels.size(0) below.
            images, labels = images.to(device), labels.to(device).reshape(-1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device).reshape(-1)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)

                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Calculate metrics
        train_accuracy = 100 * train_correct / train_total
        val_accuracy = 100 * val_correct / val_total
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(test_loader)

        # Calculate F1 score
        f1 = f1_score(all_labels, all_predictions, average='weighted')

        # Update scheduler
        scheduler.step()

        # Store metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}: Train Loss: {avg_train_loss:.4f} | '
              f'Train Acc: {train_accuracy:.2f}% | Val Loss: {avg_val_loss:.4f} | '
              f'Val Acc: {val_accuracy:.2f}% | F1: {f1:.4f}')

        # Save best model based on F1 score
        if f1 > best_f1:
            best_f1 = f1
            # state_dict() returns references to the live parameter tensors,
            # so a shallow dict copy would keep changing as training goes on.
            # A deep copy freezes this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())

    print(f"\nBest F1 Score: {best_f1:.4f}")

    return {
        'model': model,
        'best_model_state': best_model_state,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'best_f1': best_f1
    }


In [None]:
# Train classification model
# NOTE(review): LEARNING_RATE is the same 1e-3 used for the segmentation runs
# and overrides train_classification_model's own default lr of 1e-4 --
# confirm that this is intended.
classification_result = train_classification_model(
    train_ids, test_ids, train_labels, test_labels,
    epochs=CLASSIFICATION_EPOCHS, batch_size=BATCH_SIZE_CLS, lr=LEARNING_RATE
)


## 14. Classification Evaluation and Results


In [None]:
# Final evaluation
def evaluate_classification_model(model, test_ids, grade_mapping):
    """Evaluate ``model`` on the test patients and collect classification metrics.

    The dataset is built from the module-level ``DATA_DIR`` and loaded in
    batches of ``BATCH_SIZE_CLS``.

    Args:
        model: Classifier returning two logits per sample (index 1 = HGG).
        test_ids: Patient ids to evaluate.
        grade_mapping: Mapping from subject ID to grade string.

    Returns:
        A dict with ``'accuracy'``, weighted ``'precision'``, ``'recall'``
        and ``'f1_score'``, ``'auc_roc'`` (0.0 when it is undefined, for
        example when only one class is present in the labels),
        ``'confusion_matrix'`` and the text ``'classification_report'``.
    """
    test_dataset = BrainTumorClassificationDataset(DATA_DIR, grade_mapping, test_ids)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_CLS, shuffle=False, num_workers=2)

    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for images, labels in test_loader:
            # Flatten labels to 1-D. A bare squeeze() would turn a final batch
            # of one sample into a 0-dim tensor, which cannot be iterated by
            # extend() below.
            images, labels = images.to(device), labels.to(device).reshape(-1)
            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities[:, 1].cpu().numpy())  # HGG probabilities

    # Calculate comprehensive metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='weighted')
    recall = recall_score(all_labels, all_predictions, average='weighted')
    f1 = f1_score(all_labels, all_predictions, average='weighted')

    try:
        auc_roc = roc_auc_score(all_labels, all_probabilities)
    except ValueError:
        # roc_auc_score rejects inputs for which the score is undefined (such
        # as a single class in all_labels). Keep the 0.0 placeholder, but let
        # every other kind of error surface.
        auc_roc = 0.0

    cm = confusion_matrix(all_labels, all_predictions)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc,
        'confusion_matrix': cm,
        'classification_report': classification_report(all_labels, all_predictions)
    }

# Evaluate model
# classification_result['model'] carries the weights left after the last
# training epoch; the snapshot taken at the best-F1 epoch is stored separately
# under classification_result['best_model_state'] and is not loaded here.
final_metrics = evaluate_classification_model(
    classification_result['model'], test_ids, grade_mapping
)

# Print results
# Precision, recall and F1 are the class-frequency-weighted averages computed
# inside evaluate_classification_model.
print("\n" + "="*60)
print("CLASSIFICATION RESULTS SUMMARY")
print("="*60)
print(f"Accuracy: {final_metrics['accuracy']:.4f}")
print(f"Precision: {final_metrics['precision']:.4f}")
print(f"Recall: {final_metrics['recall']:.4f}")
print(f"F1 Score: {final_metrics['f1_score']:.4f}")
print(f"AUC-ROC: {final_metrics['auc_roc']:.4f}")
print(f"\nConfusion Matrix:")
print(final_metrics['confusion_matrix'])
print(f"\nClassification Report:")
print(final_metrics['classification_report'])


## 15. Final Summary


In [None]:
# Recap of both parts, read from segmentation_results and final_metrics.
print("\n" + "="*80)
print("FINAL PROJECT SUMMARY")
print("="*80)

# One line per segmentation configuration with its best validation Dice.
print("\n--- SEGMENTATION RESULTS ---")
for model_name, results in segmentation_results.items():
    print(f"{model_name}: Best Dice Score = {results['best_dice']:.4f}")

# (model_name, result_dict) pair with the highest best_dice; max() raises
# ValueError if segmentation_results is empty.
best_seg_model = max(segmentation_results.items(), key=lambda x: x[1]['best_dice'])
print(f"\nBest Segmentation Model: {best_seg_model[0]} (Dice: {best_seg_model[1]['best_dice']:.4f})")

# Headline classification metrics from the final evaluation.
print("\n--- CLASSIFICATION RESULTS ---")
print(f"Accuracy: {final_metrics['accuracy']:.4f}")
print(f"F1 Score: {final_metrics['f1_score']:.4f}")
print(f"AUC-ROC: {final_metrics['auc_roc']:.4f}")

print("\n" + "="*80)
print("Project 1 Complete!")
