# Import Library

In [None]:
!pip install torchio

In [None]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models.video import r3d_18
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
import math

from tqdm import tqdm

import torchio as T

import warnings
warnings.filterwarnings('ignore')

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU and all CUDA devices).

    Side effects:
        Switches cuDNN to deterministic mode and turns off its autotuner so
        that repeated runs pick the same kernels.
    """
    # Local import: the notebook's import cell does not bring in ``random``.
    import random

    random.seed(seed)                         # Python stdlib RNG
    np.random.seed(seed)                      # NumPy global RNG
    torch.manual_seed(seed)                   # PyTorch CPU
    torch.cuda.manual_seed(seed)              # PyTorch GPU (current device)
    torch.cuda.manual_seed_all(seed)          # PyTorch GPU (all devices)
    torch.backends.cudnn.deterministic = True # cuDNN deterministic kernels
    torch.backends.cudnn.benchmark = False    # no speed autotuning


set_seed(42)

# Load dataset

In [None]:
class AlzheimerDataset(Dataset):
    """MRI volumes paired with demographics and follow-up cognitive scores.

    Every item is a tuple ``(mri_image, demographics, targets)``:

    * ``mri_image``: float32 tensor, the volume min-max scaled to [0, 1] and
      repeated over 3 channels, then passed through ``transform`` if given.
    * ``demographics``: float32 tensor ordered as
      ``[age, PTGENDER, PTEDUCAT, DIAGNOSIS_now, time_lapsed]``.
    * ``targets``: dict of scalar float32 tensors keyed by
      ``adas11_future``, ``adas13_future``, ``mmscore_future`` and
      ``cdglobal_future``.
    """

    # Columns that must be present and non-null for a row to be usable.
    REQUIRED_COLS = (
        'PTGENDER', 'age', 'PTEDUCAT', 'time_lapsed', 'DIAGNOSIS_now',
        'ADAS11_future', 'ADAS13_future', 'MMSCORE_future', 'CDGLOBAL_future',
    )

    # File looked for first inside each image folder.
    PREFERRED_NII_NAME = "T1_biascorr_brain.nii.gz"

    def __init__(self,
                 image_folder,
                 data_csv,
                 transform=None):
        """
        Args:
            image_folder: Path to folder containing ``I<image_id>`` subfolders
            data_csv: Path to CSV file containing all data (demographics, cognitive scores, follow-up)
            transform: Optional callable applied to the (3, D, H, W) image tensor
        """
        self.image_folder = image_folder
        self.transform = transform

        # Load CSV file
        self.data_df = pd.read_csv(data_csv)

        # Process data
        self.data_df = self._process_data()

        # Keep only rows that have an image file and all required columns
        self.valid_samples = self._filter_valid_samples()

    def _process_data(self):
        """Return a copy of the table with ``PTGENDER`` cast to float32."""
        df = self.data_df.copy()

        # NOTE(review): assumes PTGENDER is already numerically encoded in the
        # CSV; string labels would make this cast raise.
        df['PTGENDER'] = df['PTGENDER'].astype(np.float32)

        return df

    @staticmethod
    def _format_image_id(image_id):
        """Render an image id for use in a folder name.

        Integer-valued floats (e.g. ``11.0``, produced when pandas upcasts a
        column or row to float) are rendered without the trailing ``.0`` so
        that the folder ``I11`` is found.
        """
        if isinstance(image_id, (float, np.floating)) and float(image_id).is_integer():
            return str(int(image_id))
        return str(image_id)

    def _find_nii_file(self, image_id):
        """Return the NIfTI path for ``image_id`` or ``None`` if there is none.

        The preferred file name is tried first; otherwise the alphabetically
        first ``.nii`` / ``.nii.gz`` file in the folder is used so that the
        choice is deterministic.
        """
        folder = os.path.join(self.image_folder, f"I{self._format_image_id(image_id)}")
        if not os.path.isdir(folder):
            return None

        preferred = os.path.join(folder, self.PREFERRED_NII_NAME)
        if os.path.exists(preferred):
            return preferred

        candidates = sorted(
            name for name in os.listdir(folder)
            if name.endswith(('.nii', '.nii.gz'))
        )
        if not candidates:
            return None
        return os.path.join(folder, candidates[0])

    def _filter_valid_samples(self):
        """List the rows that have an image file and complete data.

        Returns:
            A list of dicts with keys ``idx`` (positional row number in
            ``self.data_df``, suitable for ``.iloc``), ``image_path`` and
            ``image_id``. Empty if the image folder or a required column is
            missing.
        """
        valid_samples = []

        if not os.path.isdir(self.image_folder):
            return valid_samples

        df = self.data_df
        # Read the id column directly instead of via iterrows(): iterrows()
        # upcasts an all-numeric row to float, turning id 11 into 11.0.
        image_ids = df['image_id'].tolist()

        required = list(self.REQUIRED_COLS)
        if any(col not in df.columns for col in required):
            return valid_samples
        is_complete = df[required].notna().all(axis=1).to_numpy()

        for position, image_id in enumerate(image_ids):
            if not is_complete[position]:
                continue

            nii_file_path = self._find_nii_file(image_id)
            if nii_file_path is None:
                continue

            valid_samples.append({
                'idx': position,
                'image_path': nii_file_path,
                'image_id': image_id
            })

        return valid_samples

    def _load_mri_image(self, image_path):
        """Load a NIfTI volume as a 3-channel float32 tensor.

        The volume is min-max scaled to [0, 1] (with a small epsilon guarding
        against constant images), converted to a tensor, repeated along a new
        leading channel axis, and finally passed through ``self.transform``.
        """
        img = nib.load(image_path)
        data = img.get_fdata().astype(np.float32)

        # Normalize intensity to [0, 1]
        data = (data - data.min()) / (data.max() - data.min() + 1e-8)

        # Add a channel axis and repeat it 3 times, as the backbone expects
        # 3-channel input
        data_tensor = torch.from_numpy(data).unsqueeze(0).repeat(3, 1, 1, 1)

        if self.transform:
            data_tensor = self.transform(data_tensor)

        return data_tensor

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        """Return ``(mri_image, demographics, targets)`` for sample ``idx``."""
        sample_info = self.valid_samples[idx]

        # 'idx' is a row position, so positional lookup is always correct
        row = self.data_df.iloc[sample_info['idx']]

        mri_image = self._load_mri_image(sample_info['image_path'])

        demographics = torch.tensor([
            row['age'],
            row['PTGENDER'],
            row['PTEDUCAT'],
            row['DIAGNOSIS_now'],
            row['time_lapsed']
        ], dtype=torch.float32)

        targets = {
            'adas11_future': torch.tensor(row['ADAS11_future'], dtype=torch.float32),
            'adas13_future': torch.tensor(row['ADAS13_future'], dtype=torch.float32),
            'mmscore_future': torch.tensor(row['MMSCORE_future'], dtype=torch.float32),
            'cdglobal_future': torch.tensor(row['CDGLOBAL_future'], dtype=torch.float32)
        }

        return mri_image, demographics, targets

# Data loading utilities
def create_data_transforms(train=True, image_size=(64, 64, 64)):
    """Build the torchio pipeline applied to each MRI tensor.

    Args:
        train: When true, random augmentations (flip, affine, noise, motion)
            follow the resize; otherwise the volume is only resized.
        image_size: Target spatial size passed to ``T.Resize``.

    Returns:
        A ``T.Compose`` of the selected transforms. No intensity
        normalisation transform is part of either pipeline.
    """
    # Both pipelines start by bringing the volume to a common size.
    steps = [T.Resize(image_size)]

    if train:
        steps.extend([
            T.RandomFlip(axes=('LR',), p=0.5),
            T.RandomAffine(scales=0.1, degrees=10, translation=5, p=0.75),
            T.RandomNoise(std=0.01, p=0.5),
            T.RandomMotion(p=0.3),
        ])

    return T.Compose(steps)

In [None]:
def create_data_loaders(image_folder,
                       data_csv,
                       batch_size=8,
                       val_split=0.1,
                       test_split=0.1,
                       image_size=(64, 64, 64)):
    """
    Create train, validation, and test data loaders

    The CSV and image folder are scanned once. Training samples go through
    the augmenting pipeline, while validation and test samples go through the
    resize-only pipeline.

    Args:
        image_folder: Path to MRI images folder
        data_csv: Path to CSV file containing all data
        batch_size: Batch size for all three loaders
        val_split: Validation split ratio
        test_split: Test split ratio
        image_size: Target size for MRI images

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        ValueError: If the split ratios are negative or leave no room for a
            training set.
    """
    # Local import: the notebook's import cell does not bring in ``copy``.
    import copy

    if val_split < 0 or test_split < 0 or val_split + test_split >= 1:
        raise ValueError(
            "val_split and test_split must be non-negative and sum to less than 1"
        )

    # Dataset used for training: carries the augmenting transform
    train_source = AlzheimerDataset(
        image_folder=image_folder,
        data_csv=data_csv,
        transform=create_data_transforms(train=True, image_size=image_size)
    )

    # Shallow copy sharing the table and sample list but with its own
    # transform. Subsets of a single dataset object would all see whatever
    # transform that object holds, i.e. validation/test would be augmented.
    eval_source = copy.copy(train_source)
    eval_source.transform = create_data_transforms(train=False, image_size=image_size)

    # Calculate split sizes
    dataset_size = len(train_source)
    test_size = int(dataset_size * test_split)
    val_size = int(dataset_size * val_split)
    train_size = dataset_size - val_size - test_size

    # Random split of the sample indices (uses the global torch RNG)
    train_indices, val_indices, test_indices = torch.utils.data.random_split(
        range(dataset_size), [train_size, val_size, test_size]
    )

    train_dataset = torch.utils.data.Subset(train_source, train_indices.indices)
    val_dataset = torch.utils.data.Subset(eval_source, val_indices.indices)
    test_dataset = torch.utils.data.Subset(eval_source, test_indices.indices)

    loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

    # Only the training loader shuffles
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"Data splits - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

# Convenience wrapper: build the data loaders from a config dict
def load_alzheimer_data(config):
    """
    Build the three data loaders from a configuration mapping.

    Required keys in ``config``:
    - image_folder: path to images
    - data_csv: path to data file

    Optional keys (default in parentheses):
    - batch_size (8): batch size
    - image_size ((64, 64, 64)): tuple of target image dimensions
    - val_split (0.1): validation fraction
    - test_split (0.1): test fraction

    Returns:
        train_loader, val_loader, test_loader
    """
    # Optional settings and the value used when the key is absent
    optional_defaults = {
        'batch_size': 8,
        'image_size': (64, 64, 64),
        'val_split': 0.1,
        'test_split': 0.1,
    }
    loader_kwargs = {
        'image_folder': config['image_folder'],
        'data_csv': config['data_csv'],
    }
    for key, fallback in optional_defaults.items():
        loader_kwargs[key] = config.get(key, fallback)

    loaders = create_data_loaders(**loader_kwargs)
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

# Model

In [None]:
# # 3D CNN backbone for MRI images - keep original pretrained weights
# cnn_backbone = r3d_18(pretrained=True)
# # Remove final classification layer
# cnn_backbone.fc = nn.Identity()
# for param in cnn_backbone.parameters():
#     print(param.requires_grad)

In [None]:
class AlzheimerMultiTaskModel(nn.Module):
    """Two-branch network predicting four follow-up cognitive scores.

    An ``r3d_18`` backbone (classification layer replaced by identity) encodes
    the MRI volume, a small MLP encodes the 5 demographic features, and the
    concatenated features pass through a shared fusion MLP into one
    sigmoid-terminated regression head per score.
    """

    # Regression heads, in the order they are registered and evaluated
    _HEAD_NAMES = ('adas11_future', 'adas13_future', 'mmscore_future', 'cdglobal_future')

    def __init__(self, pretrained=True, dropout_rate=0.3):
        """
        Args:
            pretrained: Forwarded to ``r3d_18(pretrained=...)``.
            dropout_rate: Dropout probability used in the demographic and
                fusion MLPs.
        """
        super().__init__()

        # Image branch: 3D ResNet with the final classifier stripped off.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` — confirm against the installed version.
        backbone = r3d_18(pretrained=pretrained)
        backbone.fc = nn.Identity()
        self.cnn_backbone = backbone

        # Demographic branch: 5 input features -> 64 -> 128
        self.demographic_net = self._dense_stack((5, 64, 128), dropout_rate)

        # Fusion of image features (512) and demographic features (128)
        image_dim, demographic_dim = 512, 128
        self.fusion_layer = self._dense_stack(
            (image_dim + demographic_dim, 512, 256), dropout_rate
        )

        # One head per follow-up score; each maps the 256 shared features to
        # a single value in (0, 1)
        self.adas11_future = self._regression_head()
        self.adas13_future = self._regression_head()
        self.mmscore_future = self._regression_head()
        self.cdglobal_future = self._regression_head()

        self._initialize_weights()

    @staticmethod
    def _dense_stack(sizes, dropout_rate):
        """Chain Linear -> ReLU -> Dropout for consecutive pairs in ``sizes``."""
        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
        return nn.Sequential(*layers)

    @staticmethod
    def _regression_head():
        """256 -> 128 -> 1 head ending in a sigmoid."""
        return nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def _initialize_weights(self):
        """Xavier-initialise every linear layer and zero its bias."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, mri_image, demographics):
        """Return a dict mapping each head name to its prediction.

        The trailing singleton dimension of each head output is squeezed away.
        """
        image_embedding = self.cnn_backbone(mri_image)
        demographic_embedding = self.demographic_net(demographics)

        shared_features = self.fusion_layer(
            torch.cat([image_embedding, demographic_embedding], dim=1)
        )

        return {
            name: getattr(self, name)(shared_features).squeeze(-1)
            for name in self._HEAD_NAMES
        }

class AlzheimerMultiTaskLoss(nn.Module):
    """Combined loss over the four follow-up score predictions.

    ADAS11 and ADAS13 use mean squared error; MMSE and CDR-global use smooth
    L1. The per-task losses are combined either with fixed manual weights or,
    when ``uncertainty_weighting`` is true, with one learnable log-variance
    per task: ``exp(-s_i) * loss_i + s_i``.

    ``log_vars`` is a parameter of this module, so it only learns if this
    module's parameters are handed to the optimizer.
    """

    # Task order; also the index order of ``log_vars``
    TASK_NAMES = ('adas11_future', 'adas13_future', 'mmscore_future', 'cdglobal_future')

    def __init__(self, uncertainty_weighting=False):
        """
        Args:
            uncertainty_weighting: Use learnable per-task log-variances
                instead of the fixed ``task_weights``.
        """
        super().__init__()

        # Fixed per-task weights used when uncertainty weighting is off
        self.task_weights = {
            'adas11_future': 3.0,
            'adas13_future': 3.0,
            'mmscore_future': 2.0,
            'cdglobal_future': 1.0
        }

        self.uncertainty_weighting = uncertainty_weighting

        # One learnable log-variance per task
        if uncertainty_weighting:
            self.log_vars = nn.Parameter(torch.zeros(len(self.TASK_NAMES)))

        # Different loss functions for different score types
        self.mse_loss = nn.MSELoss()
        self.smooth_l1_loss = nn.SmoothL1Loss()

    def forward(self, predictions, targets):
        """Compute per-task and total losses.

        Args:
            predictions: Dict of prediction tensors keyed by task name.
            targets: Dict of target tensors keyed by task name.

        Returns:
            Dict with the unweighted ``'<task>_loss'`` for every task and the
            weighted sum under ``'total_loss'``.
        """
        criteria = {
            'adas11_future': self.mse_loss,
            'adas13_future': self.mse_loss,
            'mmscore_future': self.smooth_l1_loss,
            'cdglobal_future': self.smooth_l1_loss,
        }

        losses = {}
        total_loss = 0

        for i, name in enumerate(self.TASK_NAMES):
            loss = criteria[name](predictions[name], targets[name])
            losses[f'{name}_loss'] = loss

            if self.uncertainty_weighting:
                # precision * loss plus a log-variance penalty that stops the
                # variance from growing without bound
                log_var = self.log_vars[i]
                total_loss = total_loss + torch.exp(-log_var) * loss + log_var
            else:
                total_loss = total_loss + self.task_weights[name] * loss

        losses['total_loss'] = total_loss
        return losses

# Training utilities
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Calling the instance with the latest validation loss returns ``True``
    once ``patience`` consecutive calls have failed to beat the best loss by
    more than ``min_delta``.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0                # calls since the last improvement
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience

        # New best: remember it and restart the patience window
        self.best_loss = val_loss
        self.counter = 0
        return False

# Factory helper: build the model together with its loss function
def create_alzheimer_model(pretrained=True, dropout_rate=0.3, 
                          uncertainty_weighting=True):
    """
    Build the multi-task network and the loss that goes with it.

    Args:
        pretrained: Forwarded to the model's 3D ResNet backbone
        dropout_rate: Dropout rate used inside the model
        uncertainty_weighting: Whether the loss learns per-task weights

    Returns:
        (model, loss_fn)
    """
    network = AlzheimerMultiTaskModel(pretrained=pretrained, dropout_rate=dropout_rate)
    criterion = AlzheimerMultiTaskLoss(uncertainty_weighting=uncertainty_weighting)
    return network, criterion

# Train function

In [None]:
def train_alzheimer_model(model, train_loader, val_loader, num_epochs=200, patience = 10, 
                         learning_rate=0.001, device='cuda', loss_fn=None):
    """
    Train ``model`` in place and checkpoint the best epoch.

    Args:
        model: Network to train; it is moved to ``device`` and updated in place.
        train_loader: Loader yielding ``(mri, demographics, targets)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping early.
        learning_rate: Peak learning rate reached after warmup.
        device: Device to train on.
        loss_fn: Optional loss module; defaults to
            ``AlzheimerMultiTaskLoss(uncertainty_weighting=True)``.

    Returns:
        (train_losses, val_losses): per-epoch average total losses.

    Raises:
        ValueError: If either loader yields no batches.

    Side effects:
        Writes ``best_alzheimer_model.pth`` whenever the save condition holds
        and prints per-epoch metrics.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    tasks = ('adas11_future', 'adas13_future', 'mmscore_future', 'cdglobal_future')

    # Factors used to report metrics in raw score units.
    # NOTE(review): presumably the targets are stored divided by these
    # maxima — confirm against the CSV.
    scaler = {
        'adas11_future': 70, 'adas13_future': 85, 'mmscore_future': 30, 'cdglobal_future': 3
    }

    if loss_fn is None:
        loss_fn = AlzheimerMultiTaskLoss(uncertainty_weighting=True)

    # Train the model that was passed in, so the caller's object is updated
    model = model.to(device)
    loss_fn = loss_fn.to(device)

    # The loss may own learnable parameters (log-variances); they must be in
    # the optimizer to be updated, and are exempt from weight decay.
    param_groups = [{'params': list(model.parameters())}]
    loss_params = list(loss_fn.parameters())
    if loss_params:
        param_groups.append({'params': loss_params, 'weight_decay': 0.0})
    optimizer = optim.AdamW(param_groups, lr=learning_rate, weight_decay=1e-4)

    # Linear warmup over 5 epochs followed by linear decay, counted in
    # optimizer steps (batches)
    num_warmup_steps = len(train_loader) * 5
    num_training_steps = len(train_loader) * num_epochs

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    early_stopping = EarlyStopping(patience=patience, min_delta=0.001)

    def to_device(mri_batch, demo_batch, targets_batch):
        """Move one batch onto the training device."""
        return (
            mri_batch.to(device),
            demo_batch.to(device),
            {k: v.to(device) for k, v in targets_batch.items()},
        )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        loss_fn.train()
        epoch_train_loss = 0
        train_task_losses = {task: 0 for task in tasks}

        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for batch in train_loop:
            mri_batch, demo_batch, targets_batch = to_device(*batch)

            optimizer.zero_grad()
            predictions = model(mri_batch, demo_batch)
            losses = loss_fn(predictions, targets_batch)

            losses['total_loss'].backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            # The schedule is expressed in batch steps, so advance it per batch
            scheduler.step()

            batch_loss = losses['total_loss'].item()
            epoch_train_loss += batch_loss
            for task in tasks:
                train_task_losses[task] += losses[f'{task}_loss'].item()

            train_loop.set_postfix({'total_loss': batch_loss})

        for task in tasks:
            train_task_losses[task] /= len(train_loader)

        # ---- Validation phase ----
        model.eval()
        loss_fn.eval()
        epoch_val_loss = 0
        val_task_losses = {task: 0 for task in tasks}
        val_predictions = {task: [] for task in tasks}
        val_targets = {task: [] for task in tasks}

        val_loop = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch")
        with torch.no_grad():
            for batch in val_loop:
                mri_batch, demo_batch, targets_batch = to_device(*batch)

                predictions = model(mri_batch, demo_batch)
                losses = loss_fn(predictions, targets_batch)

                batch_loss = losses['total_loss'].item()
                epoch_val_loss += batch_loss
                for task in tasks:
                    val_task_losses[task] += losses[f'{task}_loss'].item()
                    val_predictions[task].extend(predictions[task].cpu().tolist())
                    val_targets[task].extend(targets_batch[task].cpu().tolist())

                val_loop.set_postfix({'total_loss': batch_loss})

        for task in tasks:
            val_task_losses[task] /= len(val_loader)

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_val_loss = epoch_val_loss / len(val_loader)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        # R² and MAE per task in raw score units. scikit-learn expects
        # (y_true, y_pred); R² is not symmetric in its arguments.
        metrics = {}
        for task in tasks:
            tar = np.array(val_targets[task]) * scaler[task]
            pre = np.array(val_predictions[task]) * scaler[task]
            metrics[f'{task}_r2'] = r2_score(tar, pre)
            metrics[f'{task}_mae'] = mean_absolute_error(tar, pre)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        print(f'  Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        for task in tasks:
            heading = task.upper().replace('_', ' ')
            print(f'    🧠 {heading}:')
            print(f'    - Train Loss: {train_task_losses[task]:.4f}, Val Loss: {val_task_losses[task]:.4f}')
            print(f'    - R²: {metrics[f"{task}_r2"]:.4f}, MAE: {metrics[f"{task}_mae"]:.4f}')

        # Save best model.
        # NOTE(review): the extra `avg_val_loss < avg_train_loss` condition
        # means no checkpoint is written while validation loss stays above
        # training loss — confirm this is intended.
        if avg_val_loss < best_val_loss and avg_val_loss < avg_train_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss_state_dict': loss_fn.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'metrics': metrics
            }, 'best_alzheimer_model.pth')
            print(f'  ✅New best model saved at epoch {epoch+1} with best_val {best_val_loss:.4f}!!!')
        # Early stopping
        if early_stopping(avg_val_loss):
            print(f"‼️Early stopping at epoch {epoch+1}")
            break

    return train_losses, val_losses

# Main

### Configuration

In [None]:
# Run settings read by load_alzheimer_data
config = dict(
    image_folder='/kaggle/input/fsl-brain-dataset/skull_stripped/skull_stripped/T1_biascorr_brain_data',
    data_csv='/kaggle/input/csv-new-data/new_data.csv',
    batch_size=4,
    image_size=(96, 96, 96),  # lower this if memory is tight
    val_split=0.1,
    test_split=0.02,
)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Build the train / validation / test loaders described by `config`
print("Loading data...")
train_loader, val_loader, test_loader = load_alzheimer_data(config)
# Optional: report the number of batches per loader
# print(f"Train: {len(train_loader)}, Val: {len(val_loader)}, Test: {len(test_loader)}")

In [None]:
# Instantiate the network and its loss, then move the network to the selected device
model, loss_fn = create_alzheimer_model(
    pretrained=True,
    uncertainty_weighting=True
)
model = model.to(device)

### Training

In [None]:
# Run training; the call returns the per-epoch average training and
# validation losses
print("Training model...")
train_losses, val_losses = train_alzheimer_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=200,
    patience=20,
    learning_rate=1e-4,
    # alternative learning rate: learning_rate=0.001,
    device=device
)

### Result

In [None]:
# Evaluation function
def evaluate_model(model, test_loader, device='cuda'):
    """Compute regression metrics for every task on a data loader.

    Args:
        model: Network returning a dict of per-task predictions.
        test_loader: Loader yielding ``(mri, demographics, targets)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Dict keyed by task name. Each value holds ``r2_score``, ``mae`` and
        ``mse`` plus the ``predictions`` and ``targets`` arrays, all after
        multiplication by the per-task scale factor.
    """
    tasks = ('adas11_future', 'adas13_future', 'mmscore_future', 'cdglobal_future')

    # Factors used to report metrics in raw score units.
    # NOTE(review): presumably the targets are stored divided by these
    # maxima — confirm against the CSV.
    scaler = {
        'adas11_future': 70, 'adas13_future': 85, 'mmscore_future': 30, 'cdglobal_future': 3
    }

    predictions = {task: [] for task in tasks}
    targets = {task: [] for task in tasks}

    model.eval()
    test_loop = tqdm(test_loader, desc="Testing", unit="batch")
    with torch.no_grad():
        for mri_batch, demo_batch, targets_batch in test_loop:
            preds = model(mri_batch.to(device), demo_batch.to(device))

            for task in tasks:
                predictions[task].extend(preds[task].cpu().numpy())
                # .cpu() keeps this correct even if a loader yields device tensors
                targets[task].extend(targets_batch[task].cpu().numpy())

    # scikit-learn metrics take (y_true, y_pred); R² is not symmetric, so the
    # order matters.
    results = {}
    for task in tasks:
        tar = np.array(targets[task]) * scaler[task]
        pre = np.array(predictions[task]) * scaler[task]

        results[task] = {
            'r2_score': r2_score(tar, pre),
            'mae': mean_absolute_error(tar, pre),
            'mse': mean_squared_error(tar, pre),
            'predictions': pre,
            'targets': tar
        }

    return results

In [None]:
# Load best model weights from a previously saved checkpoint
print("Loading best model...")
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, so
# only use it on checkpoint files you trust.
# map_location places the loaded tensors on the CPU; load_state_dict then
# copies them into the existing model's parameters.
checkpoint = torch.load('/kaggle/input/brain_model/pytorch/default/4/best_alzheimer_model.pth', weights_only=False, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# # Load best model
# print("Loading best model...")
# checkpoint = torch.load('best_alzheimer_model.pth', weights_only=False)
# model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Print, batch by batch, the predicted and actual scores rescaled to raw
# score units.
# (display name, output key, scale factor) for every task
score_panels = (
    ('ADAS11 FUTURE', 'adas11_future', 70),
    ('ADAS13 FUTURE', 'adas13_future', 85),
    ('MMSCORE FUTURE', 'mmscore_future', 30),
    ('CDGLOBAL FUTURE', 'cdglobal_future', 3),
)

model.eval()
# Inference only: no_grad avoids building the autograd graph for each batch
with torch.no_grad():
    for batch_number, (mri_test, demographics_test, label_test) in enumerate(test_loader, start=1):
        print(f"Test {batch_number}:")
        mri_test = mri_test.to(device)
        demographics_test = demographics_test.to(device)

        # Predict
        out = model(mri_test, demographics_test)

        for title, key, scale in score_panels:
            # Undo the scaling of the values
            predicted = (out[key] * scale).cpu().tolist()
            actual = (label_test[key] * scale).cpu().tolist()
            print(f'🧠{title}:')
            print(f'🔵Predicted: {predicted}\n🔴Actual:    {actual}')
        print('=============================================================================================================')
print("Done!")


In [None]:
def plot_fig(data, path):
    """Save a 2x2 grid of predicted-vs-actual scatter plots, one per score.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        data: Loader yielding ``(mri, demographics, targets)`` batches.
        path: File the figure is written to.
    """
    import matplotlib.pyplot as plt

    # (output key, legend label, scale factor, end points of the y = x line)
    panels = (
        ('adas11_future', 'ADAS11 FUTURE', 70, (0, 70)),
        ('adas13_future', 'ADAS13 FUTURE', 85, (0, 85)),
        ('mmscore_future', 'MMSCORE FUTURE', 30, (-2, 30)),
        ('cdglobal_future', 'CDGLOBAL FUTURE', 3, (0, 3)),
    )
    predicted = {key: [] for key, _, _, _ in panels}
    actual = {key: [] for key, _, _, _ in panels}

    model.eval()
    # Inference only: no_grad avoids building the autograd graph for each batch
    with torch.no_grad():
        for mri_test, demographics_test, label_test in data:
            out = model(mri_test.to(device), demographics_test.to(device))

            # Undo the scaling so both axes are in raw score units
            for key, _, scale, _ in panels:
                predicted[key].extend((out[key] * scale).cpu().tolist())
                actual[key].extend((label_test[key] * scale).cpu().tolist())

    fig = plt.figure(figsize=(15, 15))
    for position, (key, label, _, diagonal) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.scatter(actual[key], predicted[key], alpha=0.6, color='blue', label=label)
        # Diagonal y = x as the perfect-prediction reference
        plt.plot(list(diagonal), list(diagonal), 'r-')
        plt.xlabel("Actual")
        plt.ylabel("Predict")
        plt.title("Score")
        plt.legend()

    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    # Release the figure; otherwise every call keeps one more figure in memory
    plt.close(fig)

In [None]:
# Evaluate a loader and print the metric report
def print_test(data):
    """Evaluate the notebook-level ``model`` on ``data`` and print MAE, RMSE and R².

    Args:
        data: Loader passed straight to ``evaluate_model``.
    """
    print("Evaluating...")
    test_results = evaluate_model(model, data, device)

    # (result key, printed name) in display order
    task_labels = (
        ('adas11_future', 'ADAS11'),
        ('adas13_future', 'ADAS13'),
        ('mmscore_future', 'MMSCORE'),
        ('cdglobal_future', 'CDGLOBAL'),
    )

    print("\nTest Results:")

    print("📊 Độ lệch trung bình (MAE):")
    print("🔵Tương lai")
    for task, label in task_labels:
        print(f"  -{label}: {test_results[task]['mae']:.4f} điểm")

    # The value printed is sqrt(MSE), i.e. RMSE, whose unit is points rather
    # than points squared; the heading and unit say so.
    print("📊 Căn bậc hai của sai số bình phương trung bình (RMSE):")
    print("🔵Tương lai")
    for task, label in task_labels:
        print(f"  -{label}: {math.sqrt(test_results[task]['mse']):.4f} điểm")

    # R² is dimensionless, so no unit is printed
    print("📊 Hệ số xác định (R² Score):")
    print("🔵Tương lai")
    for task, label in task_labels:
        print(f"  -{label}: {test_results[task]['r2_score']:.4f}")

**Train**

In [None]:
# Print the metric report for the training loader
print_test(train_loader)

In [None]:
# Save the predicted-vs-actual scatter plots for the training loader
plot_fig(train_loader, 'train_plot.png')

**Val**

In [None]:
# Print the metric report for the validation loader
print_test(val_loader)

In [None]:
# Save the predicted-vs-actual scatter plots for the validation loader
plot_fig(val_loader, 'val_plot.png')

**Test**

In [None]:
# Print the metric report for the test loader
print_test(test_loader)

In [None]:
# Save the predicted-vs-actual scatter plots for the test loader
plot_fig(test_loader, 'test_plot.png')