In [1]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Enable MPS fallback for unsupported operations
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

sys.path.append('..')

from dataset import PreprocessedDataset
from models import ViTClassifier
from utils import (
    MetricsCalculator,
    compute_class_weights,
    Trainer,
    WeightedBCELoss
)
from utils.config_utils import load_config, load_env, get_device, print_config, validate_config
from utils.data_analysis import analyze_dataset

import wandb

# Seed the torch and NumPy global RNGs from a single constant.
SEED = 5252
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(SEED)

# Report the torch version and which accelerator backends torch says are usable.
mps_fallback_flag = os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK', '0')
runtime_report = [
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    f"MPS available: {torch.backends.mps.is_available()}",
    f"MPS fallback enabled: {mps_fallback_flag}",
]
print("\n".join(runtime_report))


PyTorch version: 2.5.1
CUDA available: False
MPS available: True
MPS fallback enabled: 1


In [2]:
# Path of the experiment configuration, relative to this notebook.
CONFIG_PATH = '../configs/base_config.yaml'

# Environment variables first (W&B credentials are read from them later),
# then the experiment config, which is validated before use.
env_vars = load_env()
config = load_config(CONFIG_PATH)
validate_config(config)

print("\nConfiguration:")
print_config(config)

# Resolve the device requested under config['hardware']['device'].
requested_device = config['hardware']['device']
device = get_device(requested_device)
print(f"\nUsing device: {device}")

Loaded environment variables from .env
✓ Configuration validated successfully

Configuration:
model:
  name: google/vit-base-patch16-224
  pretrained: True
  num_classes: 5
  input_channels: 9
  channel_adaptation: avg_pool
  dropout: 0.1
  backend: huggingface
data:
  data_root: processed_data
  train_split: train
  val_split: val
  test_split: test
  batch_size: 16
  num_workers: 4
  pin_memory: True
training:
  epochs: 1
  learning_rate: 0.0001
  weight_decay: 1e-05
  optimizer: adamw
  scheduler: cosine
  warmup_epochs: 5
  gradient_clip: 1.0
loss:
  type: weighted_bce
  pos_weight_strategy: inverse_freq
metrics:
  track_per_class: True
  metrics_list: ['auc_roc', 'precision', 'recall', 'f1_score', 'accuracy']
  threshold: 0.5
wandb:
  project: Deep Machine Learning Project
  entity: None
  log_interval: 10
  log_model: True
  watch_model: True
checkpoint:
  save_dir: checkpoints
  save_frequency: 5
  save_best: True
  metric_for_best: val_auc_roc_macro
  mode: max
hardware:
  devi

## 2. Load Datasets


In [3]:
# Create datasets
data_cfg = config['data']
data_root = data_cfg['data_root']

# Split directory names come from the config (train_split / val_split /
# test_split). The defaults are the names that used to be hard-coded here,
# so a config without these keys behaves exactly as before.
split_dirs = {
    'train': data_cfg.get('train_split', 'train'),
    'val': data_cfg.get('val_split', 'val'),
    'test': data_cfg.get('test_split', 'test'),
}


def _split_dataset(split_name):
    """Return the PreprocessedDataset rooted at ../<data_root>/<directory for split_name>."""
    return PreprocessedDataset(
        root_dir=os.path.join('..', data_root, split_dirs[split_name])
    )


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the batch/worker settings from config['data'].

    Args:
        dataset: dataset to iterate over.
        shuffle: whether to reshuffle each epoch (True for training only).

    Returns:
        A torch.utils.data.DataLoader over `dataset`.
    """
    return DataLoader(
        dataset,
        batch_size=data_cfg['batch_size'],
        shuffle=shuffle,
        num_workers=data_cfg['num_workers'],
        pin_memory=data_cfg['pin_memory'],
    )


train_dataset = _split_dataset('train')
val_dataset = _split_dataset('val')
test_dataset = _split_dataset('test')

print("Dataset sizes:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")
print(f"  Test:  {len(test_dataset)} samples")

# Create data loaders: only the training loader shuffles, so validation and
# test batches are produced in a fixed order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print("\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")


Dataset sizes:
  Train: 1828 samples
  Val:   228 samples
  Test:  229 samples

DataLoaders created:
  Train batches: 115
  Val batches:   15
  Test batches:  15


In [5]:
# Per-class weights for the loss, computed from the training split with the
# strategy named in config['loss']['pos_weight_strategy'].
weight_strategy = config['loss']['pos_weight_strategy']
class_weights = compute_class_weights(train_dataset, strategy=weight_strategy)

# Label order used for display.
# NOTE(review): assumed to match the order of entries in class_weights — confirm against the dataset.
class_names = ['epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']

print(f"\nComputed class weights:")
for label, class_weight in zip(class_names, class_weights):
    print(f"  {label:20s}: {class_weight:.4f}")

# Keep the weights on the same device as the model so the loss can use them directly.
class_weights = class_weights.to(device)

Computing class weights from dataset...

Class Distribution:
------------------------------------------------------------
epidural            :    87 pos ( 4.76%),  1741 neg, weight: 20.0115
intraparenchymal    :   449 pos (24.56%),  1379 neg, weight: 3.0713
intraventricular    :   398 pos (21.77%),  1430 neg, weight: 3.5930
subarachnoid        :   515 pos (28.17%),  1313 neg, weight: 2.5495
subdural            :   470 pos (25.71%),  1358 neg, weight: 2.8894
------------------------------------------------------------

Computed class weights:
  epidural            : 20.0115
  intraparenchymal    : 3.0713
  intraventricular    : 3.5930
  subarachnoid        : 2.5495
  subdural            : 2.8894


In [6]:
# Build the ViT classifier from the `model` section of the config.
model_cfg = config['model']
backend = model_cfg.get('backend', 'huggingface')

model = ViTClassifier(
    model_name=model_cfg['name'],
    num_classes=model_cfg['num_classes'],
    pretrained=model_cfg['pretrained'],
    input_channels=model_cfg['input_channels'],
    channel_adaptation=model_cfg['channel_adaptation'],
    dropout=model_cfg['dropout'],
    backend=backend,
)

# Freeze the backbone before moving the model to the target device.
model.freeze_backbone()
model = model.to(device)

# Model summary: count every parameter once, remembering whether it is trainable.
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, is_trainable in param_sizes if is_trainable)

print(f"\nModel: {model_cfg['name']}")
print(f"Backend: {backend}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} (backbone frozen)")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Model: google/vit-base-patch16-224
Backend: huggingface
Total parameters: 85,802,501
Trainable parameters: 3,845 (backbone frozen)


In [7]:
train_cfg = config['training']

# Loss: BCE weighted per class with the weights computed from the training split.
criterion = WeightedBCELoss(pos_weights=class_weights)
print("Using Weighted Binary Cross Entropy Loss")

# Optimizer: AdamW over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
)

# Learning-rate scheduler, selected by name. Each entry pairs a lazy factory
# with the message to print, so only the chosen scheduler is ever constructed.
scheduler_options = {
    'cosine': (
        lambda: optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=train_cfg['epochs'],
            eta_min=1e-6,
        ),
        "Using Cosine Annealing LR Scheduler",
    ),
    'step': (
        lambda: optim.lr_scheduler.StepLR(
            optimizer,
            step_size=10,
            gamma=0.1,
        ),
        "Using Step LR Scheduler",
    ),
}
# Any other name means "no scheduler".
build_scheduler, scheduler_message = scheduler_options.get(
    train_cfg['scheduler'],
    (lambda: None, "No LR Scheduler"),
)
scheduler = build_scheduler()
print(scheduler_message)

print(f"\nOptimizer: AdamW")
print(f"  Learning rate: {train_cfg['learning_rate']}")
print(f"  Weight decay: {train_cfg['weight_decay']}")


Using Weighted Binary Cross Entropy Loss
Using Cosine Annealing LR Scheduler

Optimizer: AdamW
  Learning rate: 0.0001
  Weight decay: 1e-05


In [8]:
# Master switch for Weights & Biases logging in the rest of the notebook.
USE_WANDB = True

if not USE_WANDB:
    print("W&B disabled")
else:
    wandb_cfg = config['wandb']

    # Log in explicitly only when an API key was provided via the environment.
    api_key = env_vars.get('WANDB_API_KEY')
    if api_key:
        wandb.login(key=api_key)

    # Start the run; entity and run name come from environment variables,
    # the project name and logged hyperparameters from the config.
    run = wandb.init(
        project=wandb_cfg['project'],
        entity=env_vars.get('WANDB_ENTITY'),
        name=env_vars.get('NOTEBOOK_NAME', 'vit_ich_training'),
        config=config,
        tags=['vit', 'ich', 'multi-label', 'transformer'],
    )

    # Optionally log gradients and parameters of the model.
    if wandb_cfg['watch_model']:
        wandb.watch(model, log='all', log_freq=100)

    print(f"✓ W&B initialized: {wandb.run.name}")
    print(f"  Project: {wandb_cfg['project']}")
    print(f"  Run URL: {wandb.run.get_url()}")


[34m[1mwandb[0m: Currently logged in as: [33mnoahmv59[0m ([33mdml_project[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/noahmv/.netrc


✓ W&B initialized: vit_ich_training
  Project: Deep Machine Learning Project
  Run URL: https://wandb.ai/dml_project/Deep%20Machine%20Learning%20Project/runs/i6r9ei35


In [9]:
# Gather everything the Trainer is constructed with in one place.
trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    config=config,
    use_wandb=USE_WANDB,
)
trainer = Trainer(**trainer_kwargs)

# Train for the number of epochs given in the config.
num_epochs = config['training']['epochs']
trainer.fit(epochs=num_epochs)


Starting Training for 1 epochs



Epoch 1 [Train]: 100%|██████████| 115/115 [01:13<00:00,  1.57it/s, loss=1.05] 
Epoch 1 [Val]: 100%|██████████| 15/15 [01:49<00:00,  7.32s/it, loss=1.42]



Epoch 1/1
  Train Loss: 1.0977 | Train AUC: 0.5410 | Train F1: 0.3074
  Val Loss:   1.1261 | Val AUC:   0.5609 | Val F1:   0.3236
  Saved checkpoint: checkpoints/checkpoint_epoch_1.pt
  ✓ New best model saved: checkpoints/best_model.pt

Training Complete!
Best val_auc_roc_macro: 0.5609



In [None]:
# Load best checkpoint
# The training log reports the best model as `checkpoints/best_model.pt`, i.e.
# under config['checkpoint']['save_dir'], whereas this cell used to look only
# in `../checkpoints/`. Check the configured directory first and keep the old
# location as a fallback, so the best weights are found wherever they were written.
checkpoint_candidates = [
    Path(config['checkpoint']['save_dir']) / 'best_model.pt',
    Path('../checkpoints/best_model.pt'),
]
best_checkpoint_path = next(
    (candidate for candidate in checkpoint_candidates if candidate.exists()),
    checkpoint_candidates[0],
)

if best_checkpoint_path.exists():
    print(f"Loading best model from {best_checkpoint_path}")
    checkpoint = trainer.load_checkpoint(best_checkpoint_path)
    print(f"  Best validation {trainer.best_metric_name}: {checkpoint['best_metric']:.4f}")
else:
    # Make the fallback visible: evaluating the last-epoch weights instead of
    # the best ones changes the reported test metrics.
    print("No checkpoint found, using current model")
    print(f"  Searched: {', '.join(str(candidate) for candidate in checkpoint_candidates)}")

# Evaluate on test set
test_metrics = trainer.test(test_loader)


In [None]:
# Create bar chart of per-class metrics
class_names_plot = ['Epidural', 'Intraparenchymal', 'Intraventricular', 'Subarachnoid', 'Subdural']
class_names_key = ['epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']

# Extract metrics (one value per class, in class_names_key order).
# AUC falls back to 0 when the per-class key is absent from test_metrics.
precisions = [test_metrics[f'precision_{name}'] for name in class_names_key]
recalls = [test_metrics[f'recall_{name}'] for name in class_names_key]
f1_scores = [test_metrics[f'f1_{name}'] for name in class_names_key]
aucs = [test_metrics.get(f'auc_roc_{name}', 0) for name in class_names_key]

# Plot: four bars per class, centred on the class tick.
x = np.arange(len(class_names_plot))
width = 0.2

fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(x - 1.5*width, precisions, width, label='Precision', color='#3498db')
ax.bar(x - 0.5*width, recalls, width, label='Recall', color='#e74c3c')
ax.bar(x + 0.5*width, f1_scores, width, label='F1-Score', color='#2ecc71')
ax.bar(x + 1.5*width, aucs, width, label='AUC-ROC', color='#9b59b6')

ax.set_xlabel('Hemorrhage Type', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(class_names_plot, rotation=45, ha='right')
ax.legend()
ax.grid(axis='y', alpha=0.3)
ax.set_ylim([0, 1])

plt.tight_layout()

# Make sure the output directory exists before writing the image.
plot_path = Path('../checkpoints/per_class_metrics.png')
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, dpi=150, bbox_inches='tight')

# Log to W&B. Pass the figure object itself (not the pyplot module) and do it
# before plt.show(), so the logged image is this figure regardless of what
# pyplot considers the "current" figure afterwards.
if USE_WANDB:
    wandb.log({"per_class_metrics": wandb.Image(fig)})

plt.show()


In [None]:
def _to_jsonable(value):
    """Recursively convert NumPy / torch values into plain Python objects.

    Args:
        value: a metric value; may be a Python scalar, a NumPy scalar or array,
            a torch tensor, or a dict/list/tuple of those.

    Returns:
        An equivalent value built only from types `json` can serialise.
        Values of any other type are returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # Covers np.floating as before, plus NumPy ints and bools.
        return value.item()
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


# Save test metrics to JSON
results = {
    'model': config['model']['name'],
    'test_metrics': {k: _to_jsonable(v) for k, v in test_metrics.items()},
    'config': config
}

# Serialise before touching the file so a serialisation error cannot leave a
# truncated test_results.json behind, and make sure the directory exists.
results_json = json.dumps(results, indent=2)
results_path = Path('../checkpoints/test_results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, 'w') as f:
    f.write(results_json)

print(f"✓ Results saved to {results_path}")

# Print summary
print("\n" + "="*80)
print("FINAL RESULTS SUMMARY")
print("="*80)
print(f"\nOverall Performance:")
print(f"  AUC-ROC (Macro):      {test_metrics['auc_roc_macro']:.4f}")
print(f"  AUC-ROC (Weighted):   {test_metrics.get('auc_roc_weighted', 0):.4f}")
print(f"  F1-Score (Macro):     {test_metrics['f1_macro']:.4f}")
print(f"  Exact Match Accuracy: {test_metrics['accuracy_exact']:.4f}")
print(f"  Hamming Accuracy:     {test_metrics['accuracy_hamming']:.4f}")
print("\n" + "="*80 + "\n")

# Finish W&B run
if USE_WANDB:
    wandb.finish()
    print("✓ W&B run finished")
    print("✓ W&B run finished")
