# Model Fine-Tuning Experiments: Pulmonary Fibrosis Binary Classification

This notebook lets you experiment with and compare different fine-tuning strategies for 3D SE-ResNet50 models on lung CT data. You will:
- Train a grid of baseline models covering every combination of optimizer, layer-freezing strategy, and learning rate
- Train an improved model (e.g., with different layer unfreezing, optimizer settings, or data augmentation)
- Evaluate and compare their performance using accuracy, precision, recall, and confusion matrices

In [None]:
# Section 1: Import Required Libraries
import torch
import torch.nn as nn
from monai.networks.nets import SEResNet50
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from preprocessing import create_train_test_val
import copy
import pandas as pd
import numpy as np

In [None]:
# Section 2: Load and Preprocess Data
BATCH_SIZE = 8
NUM_WORKERS = 4
TEST_SIZE = 0.25
# NOTE(review): presumably the fraction of the held-out split that becomes
# validation data -- confirm against preprocessing.create_train_test_val.
VAL_SIZE = 0.5
RANDOM_STATE = 42

# Later cells move every model to `device`, but nothing earlier in this
# notebook defines it; use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load data using preprocessing utilities
splits = create_train_test_val(
    "dataset", "dataset/labels_binary.csv",
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
    test_size=TEST_SIZE, val_size=VAL_SIZE, random_state=RANDOM_STATE,
)
train_loader = splits['train_loader']
val_loader = splits['val_loader']
test_loader = splits['test_loader']

In [None]:
# Section 3: Define and Train Models with Multiple Optimizers and Freezing Strategies

def set_freeze_strategy(model, strategy, head_prefixes=("last_linear", "fc")):
    """Set ``requires_grad`` on every parameter of *model* according to *strategy*.

    Strategies:
        'none'              -- every parameter is trainable.
        'freeze_layer1'     -- parameters of the top-level ``layer1`` submodule
                               are frozen; everything else is trainable.
        'freeze_all_but_fc' -- only the classification head is trainable.

    Args:
        model: ``torch.nn.Module`` whose parameters are modified in place.
        strategy: One of the strategy names above.
        head_prefixes: Names of top-level submodules treated as the
            classification head for 'freeze_all_but_fc'. MONAI's SENet family
            (including ``SEResNet50``) names its head ``last_linear``;
            torchvision-style ResNets use ``fc``.

    Raises:
        ValueError: If *strategy* is unknown, or if 'freeze_all_but_fc' finds
            no parameters under any of *head_prefixes* (which would leave the
            optimizer with nothing to train).
    """
    # Match on the "<module>." prefix of parameter names rather than a
    # substring: in MONAI's SE-ResNet every squeeze-and-excitation block has
    # an internal submodule called 'fc', so `'fc' in name` selects those
    # blocks and misses the actual classifier head.
    head_dots = tuple(f"{prefix}." for prefix in head_prefixes)

    if strategy == 'none':
        def is_trainable(name):
            return True
    elif strategy == 'freeze_layer1':
        # Trailing dot so that e.g. a 'layer10' module is not frozen too.
        def is_trainable(name):
            return not name.startswith('layer1.')
    elif strategy == 'freeze_all_but_fc':
        def is_trainable(name):
            return name.startswith(head_dots)
    else:
        raise ValueError(f"Unknown freezing strategy: {strategy}")

    named_params = list(model.named_parameters())
    if strategy == 'freeze_all_but_fc' and not any(is_trainable(n) for n, _ in named_params):
        raise ValueError(
            f"No classification-head parameters found under prefixes {head_prefixes!r}"
        )
    for name, param in named_params:
        param.requires_grad = is_trainable(name)

optimizers_config = {
    'AdamW': lambda params, lr: torch.optim.AdamW(params, lr=lr),
    'SGD': lambda params, lr: torch.optim.SGD(params, lr=lr, momentum=0.9),
}
freezing_strategies = ['none', 'freeze_layer1', 'freeze_all_but_fc']
learning_rates = [1e-4, 5e-5]
num_epochs = 10

experiment_results = []
experiment_histories = []

# NOTE(review): `train_model` and `evaluate_model` are neither defined nor
# imported in this notebook -- confirm which module provides them.
# NOTE(review): SEResNet50 is built below without loading any pretrained
# weights, so the frozen layers keep their random initialization. The
# 'freeze_*' strategies only amount to fine-tuning once pretrained weights
# are loaded -- confirm intent.
for opt_name, opt_fn in optimizers_config.items():
    for freeze_strategy in freezing_strategies:
        for lr in learning_rates:
            print(f"\n--- Training with {opt_name}, freeze: {freeze_strategy}, lr: {lr} ---")
            # Re-seed so every configuration starts from the same initial
            # weights; otherwise differences between runs mix hyperparameter
            # effects with initialization noise.
            torch.manual_seed(RANDOM_STATE)
            model = SEResNet50(spatial_dims=3, in_channels=1, num_classes=2).to(device)
            set_freeze_strategy(model, freeze_strategy)
            optimizer = opt_fn(filter(lambda p: p.requires_grad, model.parameters()), lr)
            criterion = nn.CrossEntropyLoss()
            trained_model, history = train_model(
                model, train_loader, val_loader, criterion, optimizer, device, num_epochs=num_epochs
            )
            # Assumes evaluate_model returns (loss, accuracy, precision, recall, ...)
            # -- TODO confirm against its definition.
            metrics = evaluate_model(trained_model, test_loader, criterion, device, ["normal", "not normal"])

            config = {
                'optimizer': opt_name,
                'freeze_strategy': freeze_strategy,
                'learning_rate': lr,
            }
            experiment_results.append({
                **config,
                'test_loss': metrics[0],
                'accuracy': metrics[1],
                'precision': metrics[2],
                'recall': metrics[3],
            })
            experiment_histories.append({**config, 'history': history})

            # Drop this run's model/optimizer before building the next 3D
            # network so twelve runs do not accumulate device memory.
            del model, trained_model, optimizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

In [None]:
# Section 4: Evaluate Each Model
# (Evaluation is performed within the experiment loop in Section 3 for each model/strategy.)

In [None]:
# Section 5: Define and Train Improved Model
# Example configuration: all layers trainable, SGD with momentum 0.9, lr=5e-5.
# NOTE(review): this exact configuration (SGD, freeze='none', lr=5e-5) is
# already one of the grid runs in Section 3, so it is not an improvement
# yet -- change at least one setting (schedule, weight decay, augmentation,
# pretrained weights, ...) before treating it as the "improved" model.

# Same seed as the grid runs so the comparison is not driven by init noise.
torch.manual_seed(RANDOM_STATE)
improved_model = SEResNet50(spatial_dims=3, in_channels=1, num_classes=2).to(device)

# Unfreeze all layers (default), or selectively freeze if desired
for param in improved_model.parameters():
    param.requires_grad = True

# Try a different optimizer and learning rate
improved_optimizer = torch.optim.SGD(improved_model.parameters(), lr=5e-5, momentum=0.9)
improved_criterion = nn.CrossEntropyLoss()

# Train improved model for the same number of epochs as the grid experiments
# so the comparison in Section 7 is like-for-like.
improved_model, improved_history = train_model(
    improved_model, train_loader, val_loader, improved_criterion, improved_optimizer, device,
    num_epochs=num_epochs,
)

In [None]:
# Section 6: Evaluate Improved Model
# `class_names` was never defined anywhere in the notebook; use the same
# labels the Section 3 grid passes to evaluate_model so results are comparable.
# NOTE(review): order presumably matches label indices 0/1 in
# dataset/labels_binary.csv -- confirm.
class_names = ["normal", "not normal"]
improved_metrics = evaluate_model(improved_model, test_loader, improved_criterion, device, class_names)

In [None]:
# Section 7: Compare Model Performance

# Gather the grid results plus, when Sections 5-6 have been run, the
# improved model, which was previously evaluated but left out of the comparison.
comparison_rows = list(experiment_results)
if 'improved_metrics' in globals():
    comparison_rows.append({
        'optimizer': f"{type(improved_optimizer).__name__} (improved)",
        'freeze_strategy': 'none',  # Section 5 leaves every layer trainable
        'learning_rate': improved_optimizer.defaults['lr'],
        'test_loss': improved_metrics[0],
        'accuracy': improved_metrics[1],
        'precision': improved_metrics[2],
        'recall': improved_metrics[3],
    })

# Create a DataFrame for all experiment metrics
metrics_df = pd.DataFrame(comparison_rows)
display(metrics_df)  # IPython/Jupyter built-in

# Experiments are categorical, so use a grouped bar chart with one labelled
# group per configuration rather than a line that implies a trend.
run_labels = [
    f"{row.optimizer} / {row.freeze_strategy} / lr={row.learning_rate:g}"
    for row in metrics_df.itertuples()
]
ax = metrics_df[['accuracy', 'precision', 'recall']].plot(kind='bar', figsize=(14, 6))
ax.set_xticklabels(run_labels, rotation=45, ha='right')
ax.set_title('Test Metrics Across Experiments')
ax.set_xlabel('Experiment')
ax.set_ylabel('Score')
ax.legend()
plt.tight_layout()
plt.show()

# Optionally, plot validation curves for each grid experiment.
# Assumes each history is a dict with a per-epoch 'val_acc' sequence -- TODO
# confirm against train_model's return value.
plt.figure(figsize=(14, 6))
for exp in experiment_histories:
    plt.plot(
        exp['history']['val_acc'],
        label=f"{exp['optimizer']}, {exp['freeze_strategy']}, lr={exp['learning_rate']}",
    )
plt.title('Validation Accuracy per Experiment')
plt.xlabel('Epoch')
plt.ylabel('Validation Accuracy (%)')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

## 1. Import Required Libraries
Import all necessary libraries for model training, evaluation, and visualization.

## 2. Load and Preprocess Data
Load the dataset and prepare train, validation, and test data loaders using the provided preprocessing utilities.

## 3. Define and Train Models with Multiple Optimizers and Freezing Strategies
This section trains one model for every combination of optimizer (AdamW, SGD), learning rate (1e-4, 5e-5), and layer-freezing strategy (none, freeze `layer1`, freeze everything except the final classification layer). The test results and training history of each run are collected.

## 4. Evaluate Each Model
Each trained model is evaluated on the test set inside the Section 3 loop. Test loss, accuracy, precision, and recall are recorded for each experiment (any confusion matrix is produced by `evaluate_model` itself).

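## 5. Define and Train Improved Model
Train one additional model with a hand-picked configuration (by default: SGD with momentum 0.9, lr=5e-5, all layers trainable) to compare against the grid.

## 6. Evaluate Improved Model
Evaluate the improved model on the same test set used for the grid experiments.

## 7. Compare Model Performance
All experiment results are summarized in a table and visualized with plots to help you identify the best fine-tuning strategy.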