In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from utils import load_data, get_data_loaders, evaluate, plot_accuracy_and_loss
from utils import gaussian_subtractive_normalization, clahe_green_channel, clahe_gaussian_blur, hist_equalization_median_blur
import os # Added for directory creation

# --- Run configuration ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
BATCH_SIZE = 16   # batch size handed to get_data_loaders
NUM_EPOCHS = 5    # epochs trained per preprocessing pipeline
NUM_CLASSES = 5   # width of the classifier head built in get_model

# Preprocessing pipelines: display name -> preprocessing callable from utils.
# One full training run is performed per entry, in insertion order.
preprocess_pipelines = {
    'gaussian_subtractive': gaussian_subtractive_normalization,
    'clahe_green': clahe_green_channel,
    'clahe_gaussian': clahe_gaussian_blur,
    'hist_eq_median': hist_equalization_median_blur,
}

# Model setup ResNet-50
def get_model():
    """Build a ResNet-50 with a fresh ``NUM_CLASSES``-way classification head.

    The backbone is loaded with torchvision's default pretrained weights and
    its final fully connected layer is replaced by dropout (p=0.5) followed by
    a linear layer sized for ``NUM_CLASSES`` outputs.

    Returns:
        The model, moved to ``DEVICE``.
    """
    net = models.resnet50(weights='ResNet50_Weights.DEFAULT')

    # Reuse the width of the stock classifier as the input of the new head.
    feature_dim = net.fc.in_features
    head = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(feature_dim, NUM_CLASSES),
    )
    net.fc = head

    return net.to(DEVICE)

# Load the three splits once; every preprocessing pipeline reuses them.
train_df, val_df, test_df = load_data()

# Report split sizes (one line per split).
print(
    f"Balanced training set size: {len(train_df)} images",
    f"Validation set size: {len(val_df)} images",
    f"Combined test set size: {len(test_df)} images",
    sep="\n",
)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, preprocess_name):
    """Train ``model`` for ``num_epochs`` epochs, checkpointing after each one.

    Every epoch performs one optimisation pass over ``train_loader``, scores
    the model on ``val_loader`` via ``evaluate`` and saves the state dict to
    ``resnet50_models/resnet50_<preprocess_name>_epoch_<n>.pth``.

    Args:
        model: Network to optimise. Only the batches are moved to ``DEVICE``
            here, so the model is expected to already live on that device.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        val_loader: Loader handed to ``evaluate`` after each epoch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            The running loss is weighted by batch size, which assumes the
            criterion returns a per-batch mean -- TODO confirm for custom losses.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.
        preprocess_name: Label embedded in the checkpoint file names.

    Returns:
        ``(train_accs, train_losses, val_accs, val_losses)``: four lists with
        one entry per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples during an epoch.
    """
    train_accs, train_losses = [], []
    val_accs, val_losses = [], []

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    save_dir = 'resnet50_models'
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        model.train()  # enable dropout / batch-norm updates for this pass
        train_loss = 0.0
        train_corrects = 0
        train_total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            train_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            # Keep the counter a plain int rather than a device tensor.
            train_corrects += (preds == labels).sum().item()
            train_total += batch_size

        if train_total == 0:
            raise ValueError(
                "train_loader yielded no samples; cannot compute training metrics"
            )

        train_loss_avg = train_loss / train_total
        train_acc = train_corrects / train_total
        train_accs.append(train_acc)
        train_losses.append(train_loss_avg)
        print(f"Train Loss: {train_loss_avg:.4f} | Train Acc: {train_acc:.4f}")

        # Validation phase.
        # evaluate() lives in utils and is not visible here; switch to eval mode
        # explicitly so dropout is disabled even if evaluate() does not do it
        # itself (harmless if it does). model.train() is restored next epoch.
        model.eval()
        val_acc, val_loss, val_precision, val_recall, val_f1 = evaluate(model, val_loader)
        val_accs.append(val_acc)
        val_losses.append(val_loss)
        print(f"Val Accuracy: {val_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val Precision: {val_precision:.4f}")
        print(f"Val Recall: {val_recall:.4f}")
        print(f"Val F1 Score: {val_f1:.4f}")

        # Save model weights (one checkpoint per epoch).
        model_save_path = os.path.join(save_dir, f'resnet50_{preprocess_name}_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

    return train_accs, train_losses, val_accs, val_losses

# Train and evaluate for each preprocessing technique.
# `results` maps each pipeline name to its per-epoch training/validation curves
# and its final test-set metrics.
results = {}

for preprocess_name, preprocess_func in preprocess_pipelines.items():
    print(f"\n--- Training with {preprocess_name} preprocessing ---")

    # Create data loaders for this preprocessing function.
    # test_dataset is returned by get_data_loaders but is not used in this section.
    train_loader, val_loader, test_loader, test_dataset = get_data_loaders(
        train_df, val_df, test_df, BATCH_SIZE, preprocess=preprocess_func
    )

    # Initialize model, criterion, and optimizer.
    # A fresh model and optimizer are built on every iteration, so no weights
    # or optimizer state carry over between preprocessing pipelines.
    model = get_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train (train_model also saves a checkpoint after every epoch).
    train_accs, train_losses, val_accs, val_losses = train_model(
        model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, preprocess_name
    )

    # Plot results for this preprocessing method.
    # NOTE(review): suptitle targets the current figure -- confirm that
    # plot_accuracy_and_loss leaves its figure open (does not call plt.show()).
    plot_accuracy_and_loss(train_accs, train_losses, val_accs, val_losses)
    plt.suptitle(f'Metrics for {preprocess_name}', y=1.02) # Title for the whole figure
    plt.show()

    # Evaluate on the *single combined* test set, using the weights from the last epoch.
    # NOTE(review): evaluate() comes from utils and is not visible here --
    # confirm it puts the model in eval mode before scoring.
    print(f"\n--- Evaluating {preprocess_name} on Combined Test Set ---")
    test_acc, test_loss, test_precision, test_recall, test_f1 = evaluate(model, test_loader)

    print(f"Combined Test - Acc: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

    # Store results for the final comparison below.
    results[preprocess_name] = {
        'train_accs': train_accs, 'train_losses': train_losses,
        'val_accs': val_accs, 'val_losses': val_losses,
        'test_acc': test_acc, 'test_loss': test_loss,
        'test_precision': test_precision, 'test_recall': test_recall, 'test_f1': test_f1
    }


# Summary: one line per preprocessing pipeline with its test accuracy and F1.
print("\n Final Test Results ")
for name, metrics in results.items():
    print(f"{name}: Test Acc: {metrics['test_acc']:.4f}, Test F1: {metrics['test_f1']:.4f}")
