In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import json

# Add src to path
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
# NOTE: 'src' is a relative entry, so it is resolved against the kernel's
# current working directory at import time.
if 'src' not in sys.path:
    sys.path.append('src')

import config
from models import get_model
from datasets.brain_tumor_dataset import get_dataloaders
from losses.combined_loss import MultiTaskLoss
from utils.metrics import calculate_classification_metrics, calculate_iou, count_parameters

# Report that the imports resolved and echo the device / model currently
# set in the project config module.
print(
    "✓ Imports successful",
    f"Device: {config.DEVICE}",
    f"Model: {config.MODEL_NAME}",
    sep="\n",
)

## 2. Configuration (Edit Here)

In [None]:
# ---- User-tunable training settings (edit these values) ----
# Options for MODEL_NAME: 'unet', 'resnet_unet', 'vit', 'swin', 'lightweight_transformer'
MODEL_NAME = 'lightweight_transformer'
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
SAVE_EVERY = 5  # a periodic checkpoint is written every SAVE_EVERY epochs

# Mirror the overrides into the shared config module so that code reading
# config.* sees the same values as this notebook.
config.MODEL_NAME = MODEL_NAME
config.BATCH_SIZE = BATCH_SIZE
config.NUM_EPOCHS = NUM_EPOCHS
config.LEARNING_RATE = LEARNING_RATE

# Echo the effective configuration in a single call.
print(
    "Training Configuration:",
    f"  Model: {MODEL_NAME}",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning Rate: {LEARNING_RATE}",
    f"  Device: {config.DEVICE}",
    sep="\n",
)

## 3. Load Data

In [None]:
print("Loading datasets...")

# All loader settings come from the config module (BATCH_SIZE reflects the
# override applied in the configuration cell).
loader_kwargs = dict(
    train_dir=config.TRAIN_DIR,
    val_dir=config.VAL_DIR,
    test_dir=config.TEST_DIR,
    train_ann=config.TRAIN_ANNOTATIONS,
    val_ann=config.VAL_ANNOTATIONS,
    test_ann=config.TEST_ANNOTATIONS,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    image_size=config.IMAGE_SIZE,
)
train_loader, val_loader, test_loader = get_dataloaders(**loader_kwargs)

# Report the number of batches per split.
batch_counts = {
    "Train": len(train_loader),
    "Val": len(val_loader),
    "Test": len(test_loader),
}
print("✓ Data loaded successfully")
print("\n".join(f"  {split} batches: {count}" for split, count in batch_counts.items()))

## 4. Initialize Model

In [None]:
print(f"Initializing {MODEL_NAME} model...")

# Build the multi-task network and move it to the configured device.
# NOTE(review): the classification head is built with NUM_CLASSES - 1 outputs;
# presumably NUM_CLASSES includes an entry that is not predicted by the
# classifier -- confirm against config and the dataset labels.
model = get_model(
    model_name=MODEL_NAME,
    n_classes_seg=config.SEGMENTATION_CLASSES,
    n_classes_cls=config.NUM_CLASSES - 1,
    img_size=config.IMAGE_SIZE
).to(config.DEVICE)

num_params = count_parameters(model)
print("✓ Model initialized")
print(f"  Parameters: {num_params:,} ({num_params/1e6:.2f}M)")

# Initialize loss, optimizer, scheduler
criterion = MultiTaskLoss(
    classification_weight=config.CLASSIFICATION_WEIGHT,
    segmentation_weight=config.SEGMENTATION_WEIGHT
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

# Halve the LR after 5 epochs without improvement in the monitored value
# (the training loop passes the validation loss to scheduler.step()).
# The `verbose` keyword is intentionally not passed: it is deprecated in
# recent PyTorch releases and, as far as I know, rejected by newer ones;
# the training loop already prints the current LR every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

print("✓ Optimizer and scheduler ready")

## 5. Training Loop with Live Visualization

In [None]:
# Training history: one entry per completed epoch for each metric.
history = {
    'train_loss': [], 'val_loss': [],
    'train_cls_acc': [], 'val_cls_acc': [],
    'train_seg_iou': [], 'val_seg_iou': []
}

best_val_loss = float('inf')
patience_counter = 0  # epochs since the last improvement in val_loss

# Create checkpoint directory
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)


def _run_epoch(loader, training):
    """Run one full pass over ``loader`` and return its aggregate metrics.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``config``. Each batch is expected to be a mapping with the keys
    'image', 'segmentation_mask' and 'classification_label'.

    Args:
        loader: iterable of batches (also needs ``len()`` for averaging).
        training: when True the model is put in train mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and gradients are disabled.

    Returns:
        Tuple ``(mean_loss, cls_accuracy, seg_iou)`` where ``mean_loss`` is the
        per-batch average of ``total_loss`` and the two metrics are whatever
        the project metric helpers return for the whole epoch.
    """
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    cls_preds, cls_targets = [], []
    seg_preds, seg_targets = [], []

    with torch.set_grad_enabled(training):
        for batch in loader:
            images = batch['image'].to(config.DEVICE)
            seg_masks = batch['segmentation_mask'].to(config.DEVICE)
            cls_labels = batch['classification_label'].to(config.DEVICE)

            if training:
                optimizer.zero_grad()

            cls_output, seg_output = model(images)
            loss = criterion(cls_output, seg_output, cls_labels, seg_masks)['total_loss']

            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            # Predictions/targets are moved to CPU so the epoch-level metrics
            # do not hold on to device memory.
            cls_preds.append(torch.argmax(cls_output, dim=1).cpu())
            cls_targets.append(cls_labels.cpu())
            seg_preds.append(torch.argmax(seg_output, dim=1).cpu())
            seg_targets.append(seg_masks.cpu())

    mean_loss = running_loss / len(loader)
    cls_accuracy = calculate_classification_metrics(
        torch.cat(cls_preds), torch.cat(cls_targets)
    )['accuracy']
    # NOTE(review): num_classes is hard-coded to 2 here while the model is
    # built with config.SEGMENTATION_CLASSES -- confirm the two agree.
    seg_iou = calculate_iou(
        torch.cat(seg_preds), torch.cat(seg_targets), num_classes=2
    )
    return mean_loss, cls_accuracy, seg_iou


def _plot_history(history):
    """Draw the loss / accuracy / IoU curves for all epochs recorded so far."""
    # (title, y-label, train key, val key, train legend, val legend)
    panels = (
        ('Loss', 'Loss', 'train_loss', 'val_loss', 'Train Loss', 'Val Loss'),
        ('Classification Accuracy', 'Accuracy', 'train_cls_acc', 'val_cls_acc', 'Train Acc', 'Val Acc'),
        ('Segmentation IoU', 'Mean IoU', 'train_seg_iou', 'val_seg_iou', 'Train IoU', 'Val IoU'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (title, ylabel, train_key, val_key, train_label, val_label) in zip(axes, panels):
        ax.plot(history[train_key], label=train_label, marker='o')
        ax.plot(history[val_key], label=val_label, marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()
    # A new figure is created every epoch; close it explicitly so figures do
    # not accumulate under backends that do not auto-close after show().
    plt.close(fig)


print(f"Starting training for {NUM_EPOCHS} epochs...\n")

for epoch in range(1, NUM_EPOCHS + 1):
    # ============= TRAINING / VALIDATION =============
    train_loss, train_cls_acc, train_seg_iou = _run_epoch(train_loader, training=True)
    val_loss, val_cls_acc, val_seg_iou = _run_epoch(val_loader, training=False)

    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_cls_acc'].append(train_cls_acc)
    history['val_cls_acc'].append(val_cls_acc)
    history['train_seg_iou'].append(train_seg_iou)
    history['val_seg_iou'].append(val_seg_iou)

    # Learning rate scheduling (plateau detection is driven by val_loss)
    scheduler.step(val_loss)

    # ============= LIVE PLOTTING =============
    # Replace the previous epoch's output with the refreshed curves.
    clear_output(wait=True)
    _plot_history(history)

    # Print epoch summary
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  Train Cls Acc: {train_cls_acc:.4f} | Val Cls Acc: {val_cls_acc:.4f}")
    print(f"  Train Seg IoU: {train_seg_iou:.4f} | Val Seg IoU: {val_seg_iou:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Save best model (lowest validation loss so far) and reset patience
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, f'{config.CHECKPOINT_DIR}/{MODEL_NAME}_best.pth')
        print(f"  ✓ Best model saved (val_loss: {val_loss:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1

    # Save periodic checkpoint
    if epoch % SAVE_EVERY == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'{config.CHECKPOINT_DIR}/{MODEL_NAME}_epoch_{epoch}.pth')
        print(f"  ✓ Checkpoint saved")

    # Early stopping once val_loss has not improved for the configured
    # number of consecutive epochs
    if patience_counter >= config.EARLY_STOPPING_PATIENCE:
        print(f"\n⚠ Early stopping triggered after {epoch} epochs")
        break

print("\n✓ Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

## 6. Save Training History

In [None]:
# Save history
os.makedirs(config.RESULTS_DIR, exist_ok=True)
history_path = f'{config.RESULTS_DIR}/{MODEL_NAME}_history.json'

# Coerce every recorded value to a built-in float before serialising:
# json cannot encode 0-dim tensors or NumPy scalars such as np.float32,
# which the metric helpers may return. Plain Python floats pass through
# unchanged, so the written file is identical in that case.
serializable_history = {
    metric: [float(value) for value in values]
    for metric, values in history.items()
}

with open(history_path, 'w') as f:
    json.dump(serializable_history, f, indent=4)

print(f"✓ Training history saved to {history_path}")

## 7. Quick Test on Sample

In [None]:
# Load best model and test on one sample
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written on GPU can still be loaded in a CPU-only (or
# different-GPU) session.
checkpoint = torch.load(
    f'{config.CHECKPOINT_DIR}/{MODEL_NAME}_best.pth',
    map_location=config.DEVICE,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get one test sample: the first item of the first test batch
test_batch = next(iter(test_loader))
image = test_batch['image'][:1].to(config.DEVICE)  # [:1] keeps the batch dimension
seg_mask_gt = test_batch['segmentation_mask'][:1].cpu().numpy()[0]
cls_label_gt = test_batch['classification_label'][:1].cpu().numpy()[0]

with torch.no_grad():
    cls_output, seg_output = model(image)
    # argmax over dim=1 picks the highest-scoring class per sample / per pixel
    cls_pred = torch.argmax(cls_output, dim=1).cpu().numpy()[0]
    seg_pred = torch.argmax(seg_output, dim=1).cpu().numpy()[0]

# Denormalize image
# The transpose moves the first axis (channels) last, as imshow expects.
# NOTE(review): the constants below are the standard ImageNet std/mean and
# assume a 3-channel input normalised with them -- confirm against the
# dataset transforms.
img = image[0].cpu().numpy().transpose(1, 2, 0)
img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
img = np.clip(img, 0, 1)

# Visualize: input, ground-truth overlay, predicted overlay
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(img)
axes[0].set_title('Input Image')
axes[0].axis('off')

axes[1].imshow(img)
axes[1].imshow(seg_mask_gt, alpha=0.5, cmap='Reds')
axes[1].set_title(f'Ground Truth\nClass: {cls_label_gt}')
axes[1].axis('off')

axes[2].imshow(img)
axes[2].imshow(seg_pred, alpha=0.5, cmap='Reds')
axes[2].set_title(f'Prediction\nClass: {cls_pred}')
axes[2].axis('off')

plt.tight_layout()
plt.show()

print(f"Classification: GT={cls_label_gt}, Pred={cls_pred}, {'✓ Correct' if cls_pred == cls_label_gt else '✗ Wrong'}")