In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import gc

import sys
sys.path.append('../src')
from models.unet import UNet
from models.enhanced_unet import EnhancedUNet, SpatialAttentionUNet, UltraLightUNet
from utils.losses import BCEDiceLoss
from utils.metrics import dice_coefficient, iou_coefficient
from data.isic_dataset import ISICDataset, load_isic_data


# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [7]:
# Load the ISIC dataset
isic_data_path = '../data/isic_2018_task1_data'
isic_data = load_isic_data(isic_data_path)

# Access the data loaders
isic_train_loader = isic_data['train_loader']
isic_val_loader = isic_data['val_loader']
isic_test_loader = isic_data['test_loader']

# Visualize some samples from the training dataset
print("Visualizing ISIC training samples:")
isic_data['visualize_samples'](isic_data['train_dataset'])

# Visualize validation samples as well
print("Visualizing ISIC validation samples:")
isic_data['visualize_samples'](isic_data['val_dataset'])

# Print dataset statistics
print(f"ISIC Dataset Statistics:")
print(f"Training samples: {isic_data['num_train_samples']}")
print(f"Validation samples: {isic_data['num_val_samples']}")
print(f"Test samples: {isic_data['num_test_samples']}")

Found 2594 image-mask pairs
Train: 1815, Validation: 389, Test: 390
ISIC Dataset Statistics:
Training samples: 1815
Validation samples: 389
Test samples: 390


In [8]:
models = {
    'unet_standard': UNet(n_channels=3, n_classes=1).to(device),
    'unet_with_depthwise': EnhancedUNet(n_channels=3, n_classes=1, use_se=False, use_lightweight=True).to(device),
    'unet_with_se_depthwise': EnhancedUNet(n_channels=3, n_classes=1, use_se=True, use_lightweight=True).to(device),
    'unet_with_se_depthwise_reduced': EnhancedUNet(n_channels=3, n_classes=1, use_se=True, use_lightweight=True, se_reduction=32).to(device),
    'unet_with_spatial_attn': SpatialAttentionUNet(n_channels=3, n_classes=1, use_se=True, use_lightweight=True).to(device),
}

# Print model architecture summary
for model_name, model in models.items():
    print(f"\n{model_name} summary:")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")


unet_standard summary:
Total parameters: 31,037,633

unet_with_depthwise summary:
Total parameters: 5,988,252

unet_with_se_depthwise summary:
Total parameters: 6,206,364

unet_with_se_depthwise_reduced summary:
Total parameters: 6,097,308

unet_with_spatial_attn summary:
Total parameters: 6,206,756


In [9]:
def train_isic_model(model, model_name, train_loader, val_loader, num_epochs=25, 
                     model_dir="../saved_models/isic/", force_train=False, resume_training=True):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    save_dir = os.path.join(model_dir, model_name)
    best_model_path = os.path.join(save_dir, 'best_model.pth')
    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
    
    criterion = BCEDiceLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=5, factor=0.5, min_lr=1e-6)
    
    train_losses = []
    val_losses = []
    val_dices = []
    val_ious = []
    learning_rates = []
    best_val_dice = 0
    start_epoch = 0
    
    if os.path.exists(checkpoint_path) and resume_training and not force_train:
        print(f"Loading checkpoint from {checkpoint_path} to resume training.")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_dice = checkpoint['val_dice']
        
        if 'train_losses' in checkpoint:
            train_losses = checkpoint['train_losses']
            val_losses = checkpoint['val_losses']
            val_dices = checkpoint['val_dices']
            val_ious = checkpoint.get('val_ious', [])
            learning_rates = checkpoint.get('learning_rates', [])
        
        print(f"Resuming from epoch {start_epoch} with best validation Dice: {best_val_dice:.4f}")
    
    # Check if best model exists
    elif os.path.exists(best_model_path) and not force_train:
        print(f"Found existing model at {best_model_path}. Skipping training.")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        
        # Evaluate the loaded model on validation set
        model.eval()
        epoch_dice = 0
        epoch_iou = 0
        batch_count = 0
        
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                
                outputs = model(images)
                dice = dice_coefficient(outputs, masks)
                iou = iou_coefficient(outputs, masks)
                
                epoch_dice += dice
                epoch_iou += iou
                batch_count += 1
        
        val_dice = epoch_dice / batch_count
        val_iou = epoch_iou / batch_count
        
        print(f"Loaded model performance - Dice: {val_dice:.4f}, IoU: {val_iou:.4f}")
        
        return {
            'model': model,
            'best_val_dice': val_dice,
            'loaded_from_checkpoint': True
        }
    
    # Memory management
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    
    # Training loop
    for epoch in range(start_epoch, num_epochs):
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)
        
        # Training phase
        model.train()
        epoch_loss = 0
        batch_count = 0
        
        for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
            images = images.to(device)
            masks = masks.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping (reduced for dermoscopic images)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.5)
            
            optimizer.step()
            
            epoch_loss += loss.item()
            batch_count += 1
            
            if batch_count % 10 == 0:
                del images, masks, outputs, loss
                torch.cuda.empty_cache()
        
        train_loss = epoch_loss / batch_count
        train_losses.append(train_loss)
        
        # Validation phase
        model.eval()
        epoch_loss = 0
        epoch_dice = 0
        epoch_iou = 0
        batch_count = 0
        
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                dice = dice_coefficient(outputs, masks)
                iou = iou_coefficient(outputs, masks)
                
                epoch_loss += loss.item()
                epoch_dice += dice
                epoch_iou += iou
                batch_count += 1
                
                del images, masks, outputs, loss
        
        val_loss = epoch_loss / batch_count
        val_dice = epoch_dice / batch_count
        val_iou = epoch_iou / batch_count
        val_losses.append(val_loss)
        val_dices.append(val_dice)
        val_ious.append(val_iou)
    
        scheduler.step(val_loss)
        
        print(f'Epoch {epoch+1:3d}/{num_epochs} | LR: {current_lr:.6f} | '
              f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | '
              f'Val Dice: {val_dice:.4f} | Val IoU: {val_iou:.4f}')
        
        # Save checkpoint every epoch for resuming
        os.makedirs(save_dir, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_dice': val_dice,
            'val_loss': val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_dices': val_dices,
            'val_ious': val_ious,
            'learning_rates': learning_rates,
            'best_val_dice': best_val_dice
        }, os.path.join(save_dir, 'checkpoint.pth'))
        
        # Save best model based on Dice coefficient
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            print(f"Saved new best model with Dice score: {val_dice:.4f}")
        
        # Memory cleanup after each epoch
        gc.collect()
        torch.cuda.empty_cache()
    
    # Save final model
    torch.save(model.state_dict(), os.path.join(save_dir, 'final_model.pth'))
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_dices': val_dices,
        'val_ious': val_ious,
        'learning_rates': learning_rates,
        'best_val_dice': best_val_dice,
        'epochs': num_epochs,
        'loaded_from_checkpoint': False
    }

In [10]:
import gc
import torch
from tqdm import tqdm

num_epochs = 25  # Slightly reduced from BUSI as ISIC has more samples
isic_results = {}

# Clear memory before starting training loop
gc.collect()
torch.cuda.empty_cache()

for model_name, model in models.items():
    print(f"\n{'='*20} Training {model_name} model on ISIC {'='*20}")
    try:
        model_results = train_isic_model(
            model, 
            model_name, 
            isic_train_loader, 
            isic_val_loader, 
            num_epochs=num_epochs,
            resume_training=True,
        )
        isic_results[model_name] = model_results
        print(f"Completed training for {model_name} model on ISIC dataset")
        
    except Exception as e:
        print(f"Error training {model_name}: {e}")
        continue
        
    gc.collect()
    torch.cuda.empty_cache()
    
# Save all results to file
import json
import os

os.makedirs('results', exist_ok=True)
with open('results/isic_training_results.json', 'w') as f:
    # Convert non-serializable data (like tensors) to Python types
    serializable_results = {}
    for model_name, results in isic_results.items():
        if isinstance(results, dict):
            serializable_model_results = {}
            for k, v in results.items():
                if isinstance(v, (list, dict, str, int, float, bool)) or v is None:
                    serializable_model_results[k] = v
                elif hasattr(v, 'tolist'):
                    try:
                        serializable_model_results[k] = v.tolist()
                    except:
                        serializable_model_results[k] = str(v)
                else:
                    serializable_model_results[k] = str(v)
            serializable_results[model_name] = serializable_model_results
    
    json.dump(serializable_results, f, indent=2)


Loading checkpoint from ../saved_models/isic/unet_standard/checkpoint.pth to resume training.
Resuming from epoch 25 with best validation Dice: 0.8762
Completed training for unet_standard model on ISIC dataset

Loading checkpoint from ../saved_models/isic/unet_with_depthwise/checkpoint.pth to resume training.
Resuming from epoch 25 with best validation Dice: 0.8869
Completed training for unet_with_depthwise model on ISIC dataset

Loading checkpoint from ../saved_models/isic/unet_with_se_depthwise/checkpoint.pth to resume training.
Resuming from epoch 25 with best validation Dice: 0.8879
Completed training for unet_with_se_depthwise model on ISIC dataset

Loading checkpoint from ../saved_models/isic/unet_with_se_depthwise_reduced/checkpoint.pth to resume training.
Resuming from epoch 25 with best validation Dice: 0.8848
Completed training for unet_with_se_depthwise_reduced model on ISIC dataset

Loading checkpoint from ../saved_models/isic/unet_with_spatial_attn/checkpoint.pth to resum