# Steatosis Classification Model Fine-tuning

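This notebook fine-tunes a pretrained DenseNet121 steatosis classifier in three phases (classifier only, partial unfreeze, full fine-tuning). Training and validation metrics (loss, F1, ROC AUC) are computed every epoch, logged to a per-run JSON file, and plotted at the end.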

In [1]:
%matplotlib inline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import json
from datetime import datetime
from tqdm.notebook import tqdm
from collections import defaultdict
import pandas as pd
import seaborn as sns

from model import SteatosisModel, get_loss_fn
from data import create_dataloaders
from evaluation import MetricsCalculator

# Fix the PyTorch random seed (CPU generator, plus the CUDA generator when a
# GPU is present) for reproducibility.
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

## Configuration

In [2]:
# --- Filesystem locations ---
MODEL_PATH = 'Models/DenseNet121_processed.pt'  # pretrained weights passed to SteatosisModel
DATA_DIR = 'DataSet'                            # dataset root passed to create_dataloaders
OUTPUT_DIR = 'training_output'                  # checkpoints, history JSON and figures
LOG_DIR = 'logs'                                # per-run metric logs

# --- Run settings ---
# The run name ends with a minute-resolution timestamp taken when this cell runs.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M")
RUN_NAME = 'DenseNet121_processed-' + run_stamp
BATCH_SIZE = 16
BINARY = True
NUM_CLASSES = 3 if not BINARY else 2

# Make sure the output and log folders exist before anything writes to them
for folder in (OUTPUT_DIR, LOG_DIR):
    Path(folder).mkdir(exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


## Data Loading

In [3]:
# Options shared by both loader builds
loader_options = dict(data_dir=DATA_DIR, binary=BINARY)

# Loaders built with batch_size=1: their lengths are reported below as the
# sample counts.
# NOTE(review): this calls create_dataloaders a second time only to count
# samples; presumably len(train_loader.dataset) gives the same number —
# confirm against create_dataloaders before simplifying.
train_samples, val_samples = create_dataloaders(batch_size=1, **loader_options)

# Loaders actually used for training and validation
train_loader, val_loader = create_dataloaders(batch_size=BATCH_SIZE, **loader_options)

print(
    f"Training samples: {len(train_samples)}",
    f"Validation samples: {len(val_samples)}",
    f"Training batches: {len(train_loader)}",
    f"Validation batches: {len(val_loader)}",
    sep="\n",
)

tensor([1.2258, 0.8445])
tensor([1.5000, 0.7500])
tensor([1.2258, 0.8445])
tensor([1.5000, 0.7500])
Training samples: 7465
Validation samples: 1083
Training batches: 467
Validation batches: 68


## Model Setup

In [4]:
# Build the network from the pretrained weights with freeze_layers=True,
# then move it to the selected device.
model = SteatosisModel(
    pretrained_path=MODEL_PATH,
    num_classes=NUM_CLASSES,
    freeze_layers=True
)
model = model.to(device)

# Adam over the parameters the model currently reports as trainable
trainable_params = model.get_trainable_params()
optimizer = model.create_optimizer(trainable_params, optimizer_type='adam', lr=1e-3)

# 'plateau' scheduler bound to that optimizer
scheduler = model.create_scheduler(optimizer, scheduler_type='plateau')

# Loss function and metric helper used by the training and validation functions
loss_fn = get_loss_fn(NUM_CLASSES)
metrics_calculator = MetricsCalculator(device=device)

Successfully loaded pretrained weights


## Training Functions

In [None]:
def compute_metrics(outputs, targets, metrics_calculator):
    """Turn raw model outputs into probabilities and compute metrics on them.

    Args:
        outputs: Raw model outputs. A 1-D tensor is passed through a sigmoid;
            anything else is passed through a softmax over dimension 1.
        targets: Ground-truth labels, forwarded unchanged to the calculator.
        metrics_calculator: Object providing ``compute_basic_metrics`` and
            ``compute_roc_auc``.

    Returns:
        The dict returned by ``compute_basic_metrics``, updated in place with
        the entries returned by ``compute_roc_auc``.
    """
    single_logit = outputs.dim() == 1
    probs = torch.sigmoid(outputs) if single_logit else torch.softmax(outputs, dim=1)

    results = metrics_calculator.compute_basic_metrics(targets, probs)
    auc_results = metrics_calculator.compute_roc_auc(
        targets,
        probs,
        multi_class=probs.dim() > 1
    )
    results.update(auc_results)
    return results

def reset_bn_stats(model):
    """Reset running statistics of every ``BatchNorm2d`` layer in ``model``.

    Each ``torch.nn.BatchNorm2d`` submodule has its running statistics reset
    and its momentum set to 0.1. Other module types are left untouched.
    """
    bn_layers = (
        layer for layer in model.modules()
        if isinstance(layer, torch.nn.BatchNorm2d)
    )
    for layer in bn_layers:
        layer.reset_running_stats()
        layer.momentum = 0.1  # Default momentum

def train_epoch(model, train_loader, optimizer, loss_fn, metrics_calculator, device):
    """Run one optimisation pass over ``train_loader`` and return epoch metrics.

    For every batch: forward pass, loss, backward pass, gradient-norm clipping
    at 1.0, optimizer step. Detached outputs and targets of all batches are
    concatenated and handed to ``compute_metrics`` once the loop is done.

    Args:
        model: Network to train; switched to train mode here.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        metrics_calculator: Forwarded to ``compute_metrics``.
        device: Device the batches are moved to.

    Returns:
        The metrics dict from ``compute_metrics`` with an added ``'loss'``
        entry: the sum of per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0.0
    collected_outputs, collected_targets = [], []

    with tqdm(train_loader, desc='Training') as progress:
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)

            # A single-column output is flattened to 1-D and paired with
            # float targets before the loss is computed.
            if logits.shape[1] == 1:
                logits = logits.squeeze(1)
                labels = labels.float()

            batch_loss = loss_fn(logits, labels)
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            loss_value = batch_loss.item()
            running_loss += loss_value
            collected_outputs.append(logits.detach())
            collected_targets.append(labels)

            # Show the current batch loss on the progress bar
            progress.set_postfix({
                'loss': f"{loss_value:.4f}",
            })

    # Metrics over the whole epoch, computed from the concatenated batches
    epoch_metrics = compute_metrics(
        torch.cat(collected_outputs),
        torch.cat(collected_targets),
        metrics_calculator,
    )
    epoch_metrics['loss'] = running_loss / len(train_loader)

    return epoch_metrics

def validate(model, val_loader, loss_fn, metrics_calculator, device):
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable yielding ``(data, target)`` batches.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        metrics_calculator: Forwarded to ``compute_metrics``.
        device: Device the batches are moved to.

    Returns:
        The metrics dict from ``compute_metrics`` over all validation batches,
        with an added ``'loss'`` entry: the sum of per-batch losses divided by
        the number of batches.
    """
    model.eval()
    total_loss = 0.0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for data, target in tqdm(val_loader, desc='Validation'):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # A single-column output is flattened to 1-D and paired with
            # float targets, mirroring the handling in train_epoch.
            if output.shape[1] == 1:
                output = output.squeeze(1)
                target = target.float()

            loss = loss_fn(output, target)
            total_loss += loss.item()

            all_outputs.append(output)
            all_targets.append(target)

    # Compute epoch metrics from the concatenated batches
    outputs = torch.cat(all_outputs)
    targets = torch.cat(all_targets)
    epoch_metrics = compute_metrics(outputs, targets, metrics_calculator)
    epoch_metrics['loss'] = total_loss / len(val_loader)

    return epoch_metrics

def _json_default(value):
    """Convert tensor/array-like values (anything with ``tolist``) for JSON."""
    if hasattr(value, 'tolist'):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def log_metrics(metrics, phase, epoch, step=None, log_file=None):
    """Append one metrics record to a JSON log file.

    The file holds a JSON list; each call loads it (if present), appends a
    record with the current timestamp, and writes the whole list back.

    Args:
        metrics: Dict of metric values to store under ``'metrics'``.
        phase: Label stored under ``'phase'`` (callers pass 'train' or 'val').
        epoch: Epoch index stored under ``'epoch'``.
        step: Optional step index; stored under ``'step'`` only when given.
        log_file: Optional path of the log file. Defaults to
            ``LOG_DIR/training_metrics-{RUN_NAME}.json``.

    Raises:
        TypeError: If a value in the record cannot be converted to JSON. The
            existing log file is left untouched in that case.
    """
    log_entry = {
        'timestamp': datetime.now().isoformat(),
        'phase': phase,
        'epoch': epoch,
        'metrics': metrics
    }
    if step is not None:
        log_entry['step'] = step

    if log_file is None:
        log_file = Path(LOG_DIR) / f'training_metrics-{RUN_NAME}.json'
    log_file = Path(log_file)
    log_file.parent.mkdir(parents=True, exist_ok=True)

    # Load existing logs if any
    logs = []
    if log_file.exists():
        with open(log_file, 'r') as f:
            logs = json.load(f)

    logs.append(log_entry)

    # Serialize before opening the file for writing, so a value that cannot be
    # encoded raises here instead of leaving a truncated log behind.
    payload = json.dumps(logs, indent=2, default=_json_default)
    with open(log_file, 'w') as f:
        f.write(payload)

## Training Loop

In [6]:
# Training phases configuration
# NOTE(review): 'Classifier Only' and 'Full Fine-tuning' both pass blocks=None
# to model.unfreeze_layers although they are meant to train different sets of
# layers — confirm what None means in SteatosisModel.unfreeze_layers.
phases = [
    {'name': 'Classifier Only', 'epochs': 10, 'blocks': None, 'lr': 1e-3},
    {'name': 'Partial Unfreeze', 'epochs': 15, 'blocks': 2, 'lr': 1e-4},
    {'name': 'Full Fine-tuning', 'epochs': 20, 'blocks': None, 'lr': 1e-5}
]

# Training history: one list per 'train_<metric>' / 'val_<metric>' key
history = defaultdict(list)

# Best model tracking
best_metric = 0.0
patience_counter = 0
early_stopping_patience = 5

# Epoch index running across all phases; logged as 'step' because the
# per-phase epoch index restarts at 0 in every phase.
global_epoch = 0

try:
    # Training loop
    for phase_idx, phase in enumerate(phases):
        print(f"\nStarting {phase['name']} (Phase {phase_idx + 1})")
        print(f"Learning rate: {phase['lr']}")

        # Update model for this phase
        model.unfreeze_layers(phase['blocks'])

        # Rebuild the optimizer from the currently trainable parameters: an
        # optimizer only updates the parameters it was constructed with, so
        # layers unfrozen for this phase must be handed to a new one. The
        # scheduler is rebuilt as well because it is bound to its optimizer.
        optimizer = model.create_optimizer(
            model.get_trainable_params(),
            optimizer_type='adam',
            lr=phase['lr']
        )
        scheduler = model.create_scheduler(optimizer, scheduler_type='plateau')

        # Each phase gets its own early-stopping budget; without this reset a
        # phase that stopped early would cut every later phase short.
        patience_counter = 0

        for epoch in range(phase['epochs']):
            print(f"\nEpoch {epoch + 1}/{phase['epochs']}")

            # Train
            train_metrics = train_epoch(
                model, train_loader, optimizer, loss_fn, metrics_calculator, device
            )

            # Log training metrics
            log_metrics(train_metrics, 'train', epoch, step=global_epoch)

            # Validate
            val_metrics = validate(
                model, val_loader, loss_fn, metrics_calculator, device
            )

            # Log validation metrics
            log_metrics(val_metrics, 'val', epoch, step=global_epoch)
            global_epoch += 1

            # Update learning rate scheduler
            # NOTE(review): ReduceLROnPlateau defaults to mode='min', but F1 is
            # a higher-is-better metric — confirm create_scheduler builds the
            # plateau scheduler with mode='max'.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_metrics['f1_score'])
            else:
                scheduler.step()

            # Update history
            for k, v in train_metrics.items():
                history[f'train_{k}'].append(v)
            for k, v in val_metrics.items():
                history[f'val_{k}'].append(v)

            # Print metrics
            print(
                f"Train Loss: {train_metrics['loss']:.4f}, "
                f"F1: {train_metrics['f1_score']:.4f}, "
                f"AUC: {train_metrics['roc_auc']:.4f}\n"
                f"Val Loss: {val_metrics['loss']:.4f}, "
                f"F1: {val_metrics['f1_score']:.4f}, "
                f"AUC: {val_metrics['roc_auc']:.4f}"
            )

            # Save best model (best validation F1 across all phases)
            if val_metrics['f1_score'] > best_metric:
                best_metric = val_metrics['f1_score']
                model.save(f"{OUTPUT_DIR}/best_model.pt")
                patience_counter = 0
                print(f"New best model saved! F1: {best_metric:.4f}")
            else:
                patience_counter += 1

            # Early stopping: ends the current phase only
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after {patience_counter} epochs without improvement")
                break

        # Save phase checkpoint
        model.save(f"{OUTPUT_DIR}/phase_{phase_idx + 1}_model.pt")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    # Save the history collected so far, even when training is interrupted
    # (e.g. KeyboardInterrupt) or fails part-way through.
    with open(f"{OUTPUT_DIR}/training_history.json", 'w') as f:
        json.dump(history, f, indent=2)


Starting Classifier Only (Phase 1)
Learning rate: 0.001

Epoch 1/10


Training:   0%|          | 0/467 [00:00<?, ?it/s]

Validation:   0%|          | 0/68 [00:00<?, ?it/s]

loss
tensor(5174.7178, device='cuda:0')
loss
tensor(0.6703, device='cuda:0')
loss
tensor(2667.6060, device='cuda:0')
loss
tensor(0.6395, device='cuda:0')
loss
tensor(3200.9885, device='cuda:0')
loss
tensor(2597.7861, device='cuda:0')
loss
tensor(0.6664, device='cuda:0')
loss
tensor(0.6418, device='cuda:0')
loss
tensor(0.6500, device='cuda:0')
loss
tensor(4538.1909, device='cuda:0')
loss
tensor(0.6678, device='cuda:0')
loss
tensor(0.6763, device='cuda:0')
loss
tensor(55297.2969, device='cuda:0')
loss
tensor(0.6941, device='cuda:0')
loss
tensor(755.2360, device='cuda:0')
loss
tensor(0.6566, device='cuda:0')
loss
tensor(0.6305, device='cuda:0')
loss
tensor(0.6685, device='cuda:0')
loss
tensor(0.6298, device='cuda:0')
loss
tensor(0.6762, device='cuda:0')
loss
tensor(79.6710, device='cuda:0')
loss
tensor(0.5862, device='cuda:0')
loss
tensor(767.4988, device='cuda:0')
loss
tensor(0.6375, device='cuda:0')
loss
tensor(0.6781, device='cuda:0')
loss
tensor(0.6396, device='cuda:0')
loss
tensor(0.

Training:   0%|          | 0/467 [00:00<?, ?it/s]

Validation:   0%|          | 0/68 [00:00<?, ?it/s]

loss
tensor(6.8308, device='cuda:0')
loss
tensor(1476.9443, device='cuda:0')
loss
tensor(0.6788, device='cuda:0')
loss
tensor(187.5409, device='cuda:0')
loss
tensor(0.5972, device='cuda:0')
loss
tensor(58305.9570, device='cuda:0')
loss
tensor(28.7566, device='cuda:0')
loss
tensor(1772.7063, device='cuda:0')
loss
tensor(0.5911, device='cuda:0')
loss
tensor(0.6766, device='cuda:0')
loss
tensor(79.7919, device='cuda:0')
loss
tensor(319.4043, device='cuda:0')
loss
tensor(0.6744, device='cuda:0')
loss
tensor(0.6226, device='cuda:0')
loss
tensor(0.6912, device='cuda:0')
loss
tensor(1398.0076, device='cuda:0')
loss
tensor(25116.0859, device='cuda:0')
loss
tensor(0.6606, device='cuda:0')
loss
tensor(17.6409, device='cuda:0')
loss
tensor(0.6282, device='cuda:0')
loss
tensor(343337.5312, device='cuda:0')
loss
tensor(5406.6714, device='cuda:0')
loss
tensor(0.6961, device='cuda:0')
loss
tensor(0.5996, device='cuda:0')
loss
tensor(6838.4707, device='cuda:0')
loss
tensor(251.5798, device='cuda:0')
l

Training:   0%|          | 0/467 [00:00<?, ?it/s]

Validation:   0%|          | 0/68 [00:00<?, ?it/s]

loss
tensor(0.6246, device='cuda:0')
loss
tensor(0.6605, device='cuda:0')
loss
tensor(0.6510, device='cuda:0')
loss
tensor(0.6310, device='cuda:0')
loss
tensor(0.6596, device='cuda:0')
loss
tensor(0.6411, device='cuda:0')
loss
tensor(0.6542, device='cuda:0')
loss
tensor(0.6981, device='cuda:0')
loss
tensor(60.7694, device='cuda:0')
loss
tensor(0.6273, device='cuda:0')
loss
tensor(0.7062, device='cuda:0')
loss
tensor(0.6841, device='cuda:0')
loss
tensor(0.6841, device='cuda:0')
loss
tensor(0.6504, device='cuda:0')
loss
tensor(0.7213, device='cuda:0')
loss
tensor(0.7075, device='cuda:0')
loss
tensor(0.6541, device='cuda:0')
loss
tensor(0.6392, device='cuda:0')
loss
tensor(0.6870, device='cuda:0')
loss
tensor(0.6518, device='cuda:0')
loss
tensor(0.6692, device='cuda:0')
loss
tensor(0.6643, device='cuda:0')
loss
tensor(428.6830, device='cuda:0')
loss
tensor(0.6738, device='cuda:0')
loss
tensor(49.3898, device='cuda:0')
loss
tensor(390.0576, device='cuda:0')
loss
tensor(374151.4375, device=

Training:   0%|          | 0/467 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Training Visualization

In [None]:
def plot_metrics(history):
    """Plot training and validation curves for the tracked metrics.

    One panel is drawn per metric for which both ``train_<metric>`` and
    ``val_<metric>`` exist in ``history``. Panels with nothing to draw
    (including the spare sixth panel of the 2x3 grid) are hidden. The figure
    is saved to ``OUTPUT_DIR/training_metrics.png`` and then shown.

    Args:
        history: Mapping from ``'train_<metric>'`` / ``'val_<metric>'`` keys
            to per-epoch value lists.
    """
    metrics = [
        ('loss', 'Loss'),
        ('f1_score', 'F1 Score'),
        ('roc_auc', 'ROC AUC'),
        ('sensitivity', 'Sensitivity'),
        ('specificity', 'Specificity')
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    drawn = set()
    for idx, (ax, (metric, title)) in enumerate(zip(axes, metrics)):
        train_key = f'train_{metric}'
        val_key = f'val_{metric}'

        # Membership tests do not create keys, even when history is a defaultdict
        if train_key in history and val_key in history:
            ax.plot(history[train_key], label='Train')
            ax.plot(history[val_key], label='Validation')
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.grid(True)
            ax.legend()
            drawn.add(idx)

    # Hide panels that received no curves instead of showing empty axes
    for idx, ax in enumerate(axes):
        if idx not in drawn:
            ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/training_metrics.png")
    plt.show()

# Plot training history
plot_metrics(history)

# Load and display detailed metrics.
# log_metrics writes to 'training_metrics-{RUN_NAME}.json', so the same
# run-specific file name must be used here.
log_file = Path(LOG_DIR) / f'training_metrics-{RUN_NAME}.json'
with open(log_file, 'r') as f:
    logs = json.load(f)

# Convert to DataFrame for analysis (nested 'metrics' dicts become columns)
df = pd.json_normalize(logs)
print("\nFinal training metrics:")
print(df[df['phase'] == 'train'].tail())
print("\nFinal validation metrics:")
print(df[df['phase'] == 'val'].tail())