# Steatosis Classification Model Fine-tuning

This notebook implements the fine-tuning process for the DenseNet121 model on the steatosis classification task.

## Setup

First, let's import all necessary dependencies and set up our environment.

In [1]:
%matplotlib inline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Optional, Tuple, List
import json
from datetime import datetime
from tqdm.notebook import tqdm
from collections import defaultdict

from model import SteatosisModel, get_loss_fn
from data import create_dataloaders
from evaluation import MetricsCalculator

# import wandb

# Seed PyTorch's random number generators for reproducibility.
# The value lives in one named constant so it can be changed in a single place.
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed every visible GPU rather than only the current device
    torch.cuda.manual_seed_all(SEED)

## Configuration

Set up training parameters and paths.

In [2]:
# --- Filesystem locations ---------------------------------------------------
MODEL_PATH = 'Models/DenseNet121_processed.pt'
DATA_DIR = 'DataSet'
OUTPUT_DIR = 'training_output'

# --- Training hyper-parameters ----------------------------------------------
BATCH_SIZE = 8
BINARY = True  # Set to False for multi-class
NUM_CLASSES = 3 if not BINARY else 2
PATIENCE = 5

# --- Fine-tuning schedule ---------------------------------------------------
# One dict per phase; 'blocks' is forwarded to model.unfreeze_layers() and 'lr'
# is the learning rate applied when the phase starts.
phases = [
    dict(name='Classifier Only', epochs=10, blocks=None, lr=1e-3),
    dict(name='Partial Unfreeze', epochs=15, blocks=2, lr=1e-4),
    dict(name='Full Fine-tuning', epochs=20, blocks=None, lr=1e-5),
]

# Make sure the directory for checkpoints and figures exists
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# wandb.login()

# # Initialize wandb run
# wandb.init(
#     project="steatosis-classification",  # Project name
#     name=f"densenet121-{NUM_CLASSES}class-{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
#     config={
#         "architecture": "DenseNet121",
#         "dataset": "steatosis",
#         "num_classes": NUM_CLASSES,
#         "batch_size": BATCH_SIZE,
#         "phases": phases,
#         "early_stopping_patience": PATIENCE
#     }
# )

In [4]:
def _metrics_json_default(value):
    """Fallback for json.dumps: convert array-like values through ``tolist()``."""
    if hasattr(value, "tolist"):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_training_metrics(metrics, output_dir, epoch=None, phase_idx=None, phases=None):
    """
    Append one metrics record to ``<output_dir>/training_metrics.json``.

    The file holds a JSON list. Each call adds a single entry made of a
    timestamp, the optional epoch/phase bookkeeping, and the values in
    ``metrics``.

    Args:
        metrics: Mapping of metric name to value for the step being logged.
        output_dir: Directory holding the log file (created if missing).
        epoch: Zero-based epoch index within the current phase.
        phase_idx: Zero-based index into ``phases``.
        phases: Phase configuration list; each item provides 'name' and 'epochs'.
            The bookkeeping fields are written only when ``epoch``,
            ``phase_idx`` and ``phases`` are all given.

    Returns:
        The path of the metrics file, as a string.
    """
    # Create output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Use a fixed filename for the metrics log
    metrics_path = f"{output_dir}/training_metrics.json"

    # Start the entry with the bookkeeping fields
    latest_metrics = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }

    # Add epoch and phase info if provided (stored one-based for readability)
    if epoch is not None and phase_idx is not None and phases is not None:
        latest_metrics["epoch"] = epoch + 1
        latest_metrics["phase"] = phase_idx + 1
        latest_metrics["phase_name"] = phases[phase_idx]['name']
        latest_metrics["global_epoch"] = epoch + 1 + sum(p['epochs'] for p in phases[:phase_idx])

    # Merge the metric values into the entry; bookkeeping keys win on a name clash
    for key, value in metrics.items():
        latest_metrics.setdefault(key, value)

    # Load existing data or start a new list
    all_metrics = []
    if Path(metrics_path).exists():
        with open(metrics_path, 'r') as f:
            try:
                all_metrics = json.load(f)
            except json.JSONDecodeError:
                all_metrics = []
    if not isinstance(all_metrics, list):
        # Keep whatever was stored, but restore the list-of-entries layout
        all_metrics = [all_metrics]

    # Append the assembled entry (not the bare metrics dict)
    all_metrics.append(latest_metrics)

    # Serialise before opening the file for writing, so a serialisation error
    # cannot truncate the log that is already on disk
    payload = json.dumps(all_metrics, indent=2, default=_metrics_json_default)
    with open(metrics_path, 'w') as f:
        f.write(payload)

    return metrics_path

## Data Loading

Create data loaders with the weighted sampling strategy.

In [5]:
try:
    # Build the loaders once. The sample counts are read from the underlying
    # datasets instead of constructing a second pair of batch_size=1 loaders.
    train_loader, val_loader = create_dataloaders(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        binary=BINARY
    )

    # Print dataset sizes
    # NOTE(review): assumes both loaders expose a sized `.dataset`, as
    # torch.utils.data.DataLoader does - confirm against create_dataloaders.
    print(f"Training examples: {len(train_loader.dataset)}")
    print(f"Validation examples: {len(val_loader.dataset)}")

    # Print number of batches per epoch
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Print sample batch to verify data format
    images, labels = next(iter(train_loader))
    print(f"\nSample batch:")
    print(f"Images shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Label values: {labels.numpy()}")

except Exception as e:
    print(f"Error loading data: {str(e)}")
    raise

Training examples: 7465
Validation examples: 1083
Training batches: 934
Validation batches: 136

Sample batch:
Images shape: torch.Size([8, 3, 224, 224])
Labels shape: torch.Size([8])
Label values: [0 0 0 1 0 1 1 0]


## Model Initialization

Load and modify the pretrained DenseNet121 model.

In [6]:
try:
    # Build the network from the pretrained checkpoint, then move it to the device
    print("Initializing model...")
    model = SteatosisModel(
        num_classes=NUM_CLASSES,
        pretrained_path=MODEL_PATH,
        freeze_layers=True,
    )
    model = model.to(device)

    # Show the architecture as a quick sanity check
    print("\nModel structure:")
    print(model)

    # Optimizer, LR scheduler, loss and metric helpers used by the training loop
    print("\nSetting up training components...")
    optimizer = model.create_optimizer(
        model.get_trainable_params(), optimizer_type='adam', lr=1e-3
    )
    scheduler = model.create_scheduler(optimizer, scheduler_type='plateau')
    loss_fn = get_loss_fn(NUM_CLASSES)
    print(f"\nLoss function: {loss_fn}")
    metrics_calculator = MetricsCalculator(device=device)

    # Parameter census: everything vs. the subset that currently requires gradients
    total_params = sum(weights.numel() for weights in model.parameters())
    trainable_params = sum(
        weights.numel() for weights in model.parameters() if weights.requires_grad
    )
    print(f"\nTrainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")

except Exception as e:
    print(f"Error initializing model: {str(e)}")
    raise

Initializing model...
Successfully loaded pretrained weights

Model structure:
SteatosisModel(
  (model): DenseNet(
    (features): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (d

## Training Functions

Define helper functions for training and validation.

In [7]:
def train_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``train_loader``.

    Returns the sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    batch_losses = []

    with tqdm(train_loader, desc='Training') as progress:
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)

            # A scalar/1-D output or a single output column means one logit per
            # sample: flatten it to (batch,) and give the loss float targets.
            single_logit = logits.dim() <= 1 or logits.shape[1] == 1
            if single_logit:
                if logits.dim() <= 1:
                    logits = logits.view(inputs.size(0))  # force the batch dimension
                else:
                    logits = logits.squeeze(1)  # drop only the feature dimension
                labels = labels.float()

            batch_loss = loss_fn(logits, labels)
            batch_loss.backward()
            optimizer.step()

            loss_value = batch_loss.item()
            batch_losses.append(loss_value)
            progress.set_postfix({'loss': f"{loss_value:.4f}"})

    return sum(batch_losses, 0.0) / len(train_loader)

def _flatten_single_logit(output, target, batch_size):
    """Reshape single-logit outputs to ``(batch_size,)`` and cast targets to float.

    Outputs with more than one column are returned unchanged, together with
    the unchanged targets.
    """
    if output.dim() <= 1:
        # Scalar or 1-D output: force an explicit batch dimension
        return output.view(batch_size), target.float()
    if output.shape[1] == 1:
        # (batch, 1): drop only the feature dimension so a batch of one stays 1-D
        return output.squeeze(1), target.float()
    return output, target


def validate(model, val_loader, loss_fn, metrics_calculator, device):
    """Evaluate ``model`` on ``val_loader`` and compute metrics.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        val_loader: Sized iterable yielding ``(data, target)`` batches.
        loss_fn: Callable applied as ``loss_fn(output, target)``.
        metrics_calculator: Object providing ``compute_basic_metrics(targets, outputs)``
            and ``compute_roc_auc(targets, outputs, multi_class=...)``, each
            returning a dict.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, metrics)``: the summed batch losses divided by
        ``len(val_loader)``, and the merged dict from both calculator calls.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for data, target in tqdm(val_loader, desc='Validation'):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Same shape handling as train_epoch. A bare squeeze() would collapse
            # a final batch of size 1 to a 0-dim tensor, which breaks both the
            # loss shape check and torch.cat below.
            output, target = _flatten_single_logit(output, target, data.size(0))

            loss = loss_fn(output, target)
            total_loss += loss.item()

            all_outputs.append(output)
            all_targets.append(target)

    if not all_outputs:
        raise ValueError("val_loader produced no batches; cannot compute validation metrics")

    # Concatenate all predictions and targets
    outputs = torch.cat(all_outputs)
    targets = torch.cat(all_targets)

    # Single-logit outputs were flattened to 1-D above; anything else is multi-class
    is_binary = outputs.dim() == 1

    # Turn logits into probabilities before computing metrics
    if is_binary:
        outputs = torch.sigmoid(outputs)
    else:
        outputs = torch.softmax(outputs, dim=1)

    # Calculate metrics
    metrics = metrics_calculator.compute_basic_metrics(targets, outputs)

    # Tell the ROC computation explicitly whether this is a multi-class problem
    metrics.update(
        metrics_calculator.compute_roc_auc(
            targets,
            outputs,
            multi_class=(not is_binary)
        )
    )

    return total_loss / len(val_loader), metrics

## Training Loop

Implement the phased fine-tuning process.

In [8]:
# Training history: one value per completed epoch across all phases.
# The visualization cell plots these lists.
history = defaultdict(list)

# Best model tracking (highest validation F1 seen so far, across all phases)
best_metric = 0.0
patience_counter = 0
early_stopping_patience = PATIENCE

try:
    # Training loop
    for phase_idx, phase in enumerate(phases):
        print(f"\nStarting {phase['name']} (Phase {phase_idx + 1})")
        print(f"Learning rate: {phase['lr']}")

        # Update model for this phase
        # NOTE(review): 'Classifier Only' and 'Full Fine-tuning' both configure
        # blocks=None, so unfreeze_layers() receives the same argument in both
        # phases - confirm against SteatosisModel.unfreeze_layers that phase 1
        # really leaves the backbone frozen.
        model.unfreeze_layers(phase['blocks'])

        # Rebuild the optimizer and scheduler for the phase. The previous
        # optimizer was created from the parameters that were trainable at that
        # time, so parameters unfrozen just above would otherwise never be
        # updated (presumably get_trainable_params() returns only parameters
        # with requires_grad=True - verify). This also applies the phase
        # learning rate and gives the scheduler a fresh state.
        optimizer = model.create_optimizer(
            model.get_trainable_params(),
            optimizer_type='adam',
            lr=phase['lr']
        )
        scheduler = model.create_scheduler(optimizer, scheduler_type='plateau')

        # Each phase gets its own early-stopping budget; without this reset a
        # phase that follows an early stop would end after one epoch without
        # improvement.
        patience_counter = 0

        for epoch in range(phase['epochs']):
            print(f"\nEpoch {epoch + 1}/{phase['epochs']}")

            # Train and validate
            train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
            val_loss, metrics = validate(model, val_loader, loss_fn, metrics_calculator, device)

            # Record the epoch so the visualization cell has data to plot
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['f1_score'].append(metrics['f1_score'])
            history['roc_auc'].append(metrics['roc_auc'])

            # Update learning rate scheduler
            # NOTE(review): ReduceLROnPlateau defaults to mode='min'; since F1
            # (higher is better) is passed here, confirm create_scheduler()
            # configures mode='max'.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(metrics['f1_score'])
            else:
                scheduler.step()

            # Log metrics to wandb
            # wandb.log({
            #     "train_loss": train_loss,
            #     "val_loss": val_loss,
            #     "f1_score": metrics['f1_score'],
            #     "roc_auc": metrics['roc_auc'],
            #     "sensitivity": metrics['sensitivity'],
            #     "specificity": metrics['specificity'],
            #     "accuracy": metrics['accuracy'],
            #     "precision": metrics['precision'],
            #     "learning_rate": optimizer.param_groups[0]['lr'],
            #     "epoch": epoch + 1 + sum(p['epochs'] for p in phases[:phase_idx])
            # })

            # Print metrics
            print(
                f"Train Loss: {train_loss:.4f}, "
                f"Val Loss: {val_loss:.4f}, "
                f"Val F1: {metrics['f1_score']:.4f}, "
                f"Val AUC: {metrics['roc_auc']:.4f}"
            )

            # Persist the epoch's metrics together with both losses
            save_training_metrics(
                {**metrics, 'train_loss': train_loss, 'val_loss': val_loss},
                "./logs",
                epoch,
                phase_idx,
                phases
            )

            # Save best model
            if metrics['f1_score'] > best_metric:
                best_metric = metrics['f1_score']
                model.save(f"{OUTPUT_DIR}/best_model.pt")
                patience_counter = 0
                print(f"New best model saved! F1: {best_metric:.4f}")
            else:
                patience_counter += 1

            # Early stopping (ends the current phase only)
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after {patience_counter} epochs without improvement")
                break

        # Save phase checkpoint
        model.save(f"{OUTPUT_DIR}/phase_{phase_idx + 1}_model.pt")
        print(f"Saved phase {phase_idx + 1} checkpoint")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise

print("\nTraining completed!")


Starting Classifier Only (Phase 1)
Learning rate: 0.001

Epoch 1/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

Validation:   0%|          | 0/136 [00:00<?, ?it/s]

Train Loss: 0.6906, Val Loss: 15925.6443, Val F1: 0.6472, Val AUC: 0.5936
New best model saved! F1: 0.6472

Epoch 2/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

Validation:   0%|          | 0/136 [00:00<?, ?it/s]

Train Loss: 0.6860, Val Loss: 138875.5473, Val F1: 0.4836, Val AUC: 0.6159

Epoch 3/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

Validation:   0%|          | 0/136 [00:00<?, ?it/s]

Train Loss: 0.6835, Val Loss: 32275.8203, Val F1: 0.6672, Val AUC: 0.6470
New best model saved! F1: 0.6672

Epoch 4/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

Validation:   0%|          | 0/136 [00:00<?, ?it/s]

Train Loss: 0.6834, Val Loss: 14595.4866, Val F1: 0.7277, Val AUC: 0.6645
New best model saved! F1: 0.7277

Epoch 5/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

Validation:   0%|          | 0/136 [00:00<?, ?it/s]

Train Loss: 0.6803, Val Loss: 53221.6444, Val F1: 0.7418, Val AUC: 0.6304
New best model saved! F1: 0.7418

Epoch 6/10


Training:   0%|          | 0/934 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Training Visualization

Plot training progress and metrics.

In [None]:
try:
    # One x-axis position per recorded epoch (across all phases)
    epochs = range(1, len(history['train_loss']) + 1)

    if not epochs:
        # Nothing was recorded: skip the figure instead of saving a blank image
        print("No training history recorded - run the training loop before plotting.")
    else:
        # Create subplots: losses on the left, validation metrics on the right
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot losses
        ax1.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
        ax1.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        # Plot metrics
        ax2.plot(epochs, history['f1_score'], 'g-', label='F1 Score')
        # In a matplotlib format string 'p' is the pentagon marker, not a
        # colour, so the purple line is requested explicitly.
        ax2.plot(epochs, history['roc_auc'], '-', color='purple', label='ROC AUC')
        ax2.set_title('Validation Metrics')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Score')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig(f"{OUTPUT_DIR}/training_history.png")
        plt.show()

except Exception as e:
    print(f"Error plotting training history: {str(e)}")
    raise

## Final Evaluation

Load the best model and compute final metrics.

In [None]:
def _final_metrics_json_default(value):
    """Fallback for json.dumps: convert array-like values through ``tolist()``."""
    if hasattr(value, "tolist"):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


try:
    # Load best model
    print("Loading best model...")
    best_model = SteatosisModel.load(
        f"{OUTPUT_DIR}/best_model.pt",
        num_classes=NUM_CLASSES,
        device=device
    )

    # Compute final metrics
    print("\nComputing final metrics...")
    _, final_metrics = validate(best_model, val_loader, loss_fn, metrics_calculator, device)

    # Print final results. Only scalar entries can be rendered with ':.4f';
    # array-valued entries are summarised by type instead of raising.
    print("\nFinal Model Performance:")
    for metric, value in final_metrics.items():
        try:
            print(f"{metric}: {float(value):.4f}")
        except (TypeError, ValueError, RuntimeError):
            print(f"{metric}: <{type(value).__name__}>")

    # Plot ROC curve when per-sample predictions are available
    # NOTE(review): validate() returns only what MetricsCalculator puts in its
    # dicts - confirm that a 'predictions' entry is produced and that its order
    # matches val_loader.dataset.labels.
    if 'predictions' in final_metrics:
        metrics_calculator.plot_roc_curves(
            val_loader.dataset.labels,
            final_metrics['predictions'],
            save_path=f"{OUTPUT_DIR}/final_roc_curve.png"
        )
    else:
        print("Skipping ROC curve: no 'predictions' entry in the final metrics.")

    # Save final metrics. Serialise first so a failure cannot leave a
    # truncated file behind.
    final_metrics_json = json.dumps(final_metrics, indent=4, default=_final_metrics_json_default)
    with open(f"{OUTPUT_DIR}/final_metrics.json", 'w') as f:
        f.write(final_metrics_json)

    print(f"\nResults saved to {OUTPUT_DIR}")

except Exception as e:
    print(f"Error in final evaluation: {str(e)}")
    raise