In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import KFold
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
from model_arch import HybridModel
from PIL import Image
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import ipywidgets as widgets
from IPython.display import display
from torch.utils.data import WeightedRandomSampler
from dataset import KneeDataset



In [2]:
class Config:
    def __init__(self):
        self.img_size = 256  # Reduced image size for memory efficiency
        self.batch_size = 4  # Reduced batch size to avoid memory issues
        self.num_epochs = 20
        self.lr = 3e-4
        self.weight_decay = 1e-4
        self.num_classes = 5  # KL grades: 0-4
        self.num_workers = 4
        self.device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
        
        # Transformer and model parameters
        self.backbone = "resnet50"
        self.transformer_heads = 4  # Reduced number of heads
        self.mlp_ratio = 2  # Reduced feedforward expansion
        self.gradient_clip_val = 1.0
        
        # Cross-validation
        self.num_folds = 5
        self.seed = 42
        
        # Loss weights
        self.seg_weight = 0.5
        self.cls_weight = 0.5
        
        # Paths
        self.train_dir = "/Users/Viku/Datasets/Medical/Knee"  # Update this
        self.val_dir = "/Users/Viku/Datasets/Medical/Knee"  # Update this
        self.output_dir = Path(f'outputs/{datetime.now().strftime("%Y%m%d_%H%M%S")}')
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.model_dir = self.output_dir / 'models'
        self.model_dir.mkdir(exist_ok=True)


In [3]:
class DiceLoss(nn.Module):
    """Dice loss for segmentation with size alignment"""
    def forward(self, pred, target, smooth=1.):
        # Ensure pred and target have the same spatial dimensions
        if pred.shape != target.shape:
            pred = F.interpolate(pred, size=target.shape[2:], mode='bilinear', align_corners=True)
        
        # Flatten predictions and targets
        pred = pred.view(-1)  # Flatten segmentation prediction
        target = target.view(-1)  # Flatten segmentation mask

        # Compute Dice loss
        intersection = (pred * target).sum()
        dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
        return 1 - dice


In [4]:
class CombinedLoss(nn.Module):
    def __init__(self, seg_weight=0.5, cls_weight=0.5, class_weights=None):
        super().__init__()
        self.seg_weight = seg_weight
        self.cls_weight = cls_weight
        self.seg_criterion = DiceLoss()
        self.cls_criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    def forward(self, seg_pred, cls_pred, seg_target, cls_target):
        seg_loss = self.seg_criterion(seg_pred, seg_target)
        cls_loss = self.cls_criterion(cls_pred, cls_target)
        return self.seg_weight * seg_loss + self.cls_weight * cls_loss

In [5]:
class AttentionVisualizer:
    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.attention_dir = os.path.join(output_dir, 'attention_maps')
        os.makedirs(self.attention_dir, exist_ok=True)

    def plot_attention_map(self, image, attention_map, prediction, true_grade, epoch, idx):
        """Visualize attention maps"""
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        
        # Original image
        axes[0].imshow(image.permute(1, 2, 0).cpu().numpy())
        axes[0].set_title(f'Original Image\nPrediction: {prediction}, True Grade: {true_grade}')
        axes[0].axis('off')

        # Attention map overlay
        sns.heatmap(attention_map.mean(0).detach().cpu().numpy(), cmap='viridis', ax=axes[1])
        axes[1].set_title('Attention Map')
        axes[1].axis('off')

        plt.tight_layout()
        save_path = os.path.join(self.attention_dir, f'epoch_{epoch}_sample_{idx}.png')
        plt.savefig(save_path)
        plt.close(fig)


In [6]:
def plot_metrics(history, output_dir):
    """Plot loss and accuracy curves"""
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss plot
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, history['train_loss'], label='Train Loss')
    plt.plot(epochs, history['val_loss'], label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Curves')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'loss_curve.png'))
    plt.close()

    # Accuracy plot
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, history['train_acc'], label='Train Accuracy')
    plt.plot(epochs, history['val_acc'], label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curves')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'accuracy_curve.png'))
    plt.close()


In [7]:
def save_model(model, optimizer, epoch, fold, val_loss, config):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'val_loss': val_loss,
    }
    os.makedirs(config.model_dir, exist_ok=True)
    save_path = os.path.join(config.model_dir, f'best_model_fold_{fold}_epoch_{epoch}.pth')
    torch.save(checkpoint, save_path)
    print(f"Model saved at epoch {epoch} with validation loss {val_loss:.4f}!")


In [8]:
def create_save_button(model, optimizer, epoch, fold, val_loss, config):
    button = widgets.Button(description="Save Model")
    output = widgets.Output()

    def on_button_click(b):
        with output:
            save_model(model, optimizer, epoch, fold, val_loss, config)

    button.on_click(on_button_click)
    display(button, output)

In [9]:
def train_with_cross_validation(config):
    # Check if running in notebook
    is_notebook = True  # You can add proper detection if needed
    num_workers = 0 if is_notebook else config.num_workers
    
    dataset = KneeDataset(config.train_dir, phase='train')
    kfold = KFold(n_splits=config.num_folds, shuffle=True, random_state=config.seed)
    
    # Initialize attention visualizer
    attention_visualizer = AttentionVisualizer(config.output_dir)

    # Calculate class weights
    class_counts = dataset.class_counts
    total_samples = sum(class_counts.values())
    class_weights = torch.FloatTensor([
        total_samples / (count * len(class_counts)) 
        for count in [class_counts[i] for i in range(5)]
    ]).to(config.device)

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"Starting Fold {fold + 1}/{config.num_folds}")

        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)

        # Create weighted sampler
        weights = dataset.get_sampling_weights()
        train_weights = weights[train_idx]
        sampler = WeightedRandomSampler(train_weights, len(train_weights))

        # Use num_workers=0 for notebook environment
        train_loader = DataLoader(train_subset, batch_size=config.batch_size, 
                                sampler=sampler, num_workers=num_workers)
        val_loader = DataLoader(val_subset, batch_size=config.batch_size, 
                              shuffle=False, num_workers=num_workers)

        model = HybridModel(num_classes=config.num_classes).to(config.device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, 
                                    weight_decay=config.weight_decay)
        criterion = CombinedLoss(seg_weight=config.seg_weight, 
                               cls_weight=config.cls_weight,
                               class_weights=class_weights)

        best_val_loss = float('inf')
        for epoch in range(config.num_epochs):
            print(f"Epoch [{epoch + 1}/{config.num_epochs}]")

            # Training phase
            model.train()
            train_loss, correct, total = 0, 0, 0

            with tqdm(train_loader, unit="batch") as train_pbar:
                for batch_idx, batch in enumerate(train_pbar):
                    train_pbar.set_description(f"Epoch [{epoch + 1}/{config.num_epochs}]")

                    images = batch['image'].to(config.device)
                    masks = batch['mask'].to(config.device)
                    grades = batch['grade'].to(config.device)

                    optimizer.zero_grad()

                    # Forward pass
                    seg_out, cls_out, attention_maps = model(images, return_attention=True)  # Modified to return attention
                    loss = criterion(seg_out, cls_out, masks, grades)

                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    _, preds = cls_out.max(1)
                    correct += (preds == grades).sum().item()
                    total += grades.size(0)

                    # Visualize attention maps for first batch of each epoch
                    if batch_idx == 0:
                        for i in range(min(4, images.size(0))):  # Visualize up to 4 images
                            attention_visualizer.plot_attention_map(
                                images[i],
                                attention_maps[i],
                                preds[i].item(),
                                grades[i].item(),
                                epoch,
                                f"train_{i}"
                            )

                    train_pbar.set_postfix({"batch_loss": loss.item()})

            train_acc = correct / total if total > 0 else 0
            avg_train_loss = train_loss / len(train_loader)
            history['train_loss'].append(avg_train_loss)
            history['train_acc'].append(train_acc)

            # Validation phase
            model.eval()
            val_loss, correct, total = 0, 0, 0

            with tqdm(val_loader, unit="batch") as val_pbar:
                for batch_idx, batch in enumerate(val_pbar):
                    val_pbar.set_description(f"Validation [{epoch + 1}/{config.num_epochs}]")

                    images = batch['image'].to(config.device)
                    masks = batch['mask'].to(config.device)
                    grades = batch['grade'].to(config.device)

                    with torch.no_grad():
                        seg_out, cls_out, attention_maps = model(images, return_attention=True)  # Modified to return attention
                        loss = criterion(seg_out, cls_out, masks, grades)

                        val_loss += loss.item()
                        _, preds = cls_out.max(1)
                        correct += (preds == grades).sum().item()
                        total += grades.size(0)

                        # Visualize attention maps for first validation batch
                        if batch_idx == 0:
                            for i in range(min(4, images.size(0))):  # Visualize up to 4 images
                                attention_visualizer.plot_attention_map(
                                    images[i],
                                    attention_maps[i],
                                    preds[i].item(),
                                    grades[i].item(),
                                    epoch,
                                    f"val_{i}"
                                )

                        val_pbar.set_postfix({"batch_loss": loss.item()})

            val_acc = correct / total if total > 0 else 0
            avg_val_loss = val_loss / len(val_loader)
            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(val_acc)

            print(f"Epoch [{epoch + 1}/{config.num_epochs}] - Train Loss: {avg_train_loss:.4f}, "
                  f"Train Acc: {train_acc:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

            # Save the best model for this fold
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                save_model(model, optimizer, epoch, fold, best_val_loss, config)

    plot_metrics(history, config.output_dir)

In [10]:
if __name__ == '__main__':
    config = Config()
    train_with_cross_validation(config)

Loaded 5778 samples for phase: train
Class distribution: {0: 2286, 1: 1046, 2: 1516, 3: 757, 4: 173}
Starting Fold 1/5
Epoch [1/20]


Epoch [1/20]:   0%|          | 0/1156 [00:00<?, ?batch/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.6784314..0.96862745].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.26274508..0.9843137].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.99215686..-0.34117645].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.6862745..0.77254903].
Epoch [1/20]:   7%|▋         | 84/1156 [02:04<26:27,  1.48s/batch, batch_loss=1.55] 


KeyboardInterrupt: 