# Model Training for Chest X-Ray Classification

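This notebook trains the classification model. Each epoch makes one pass over the training set and then measures the loss on the validation set. The learning rate is reduced when the validation loss plateaus, and training stops early once the validation loss stops improving, to limit overfitting.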

In [1]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
def train_model(model, train_dl, val_dl, epochs=50, patience=15, *,
                lr=0.001, weight_decay=1e-4, device=None):
    """
    Train the model with Adam, LR-on-plateau scheduling and early stopping.

    Each epoch runs one optimisation pass over ``train_dl`` followed by an
    evaluation pass over ``val_dl``. The learning rate is multiplied by 0.1
    when the mean validation loss stops improving (scheduler patience 5).
    Training stops once the validation loss has not improved for
    ``patience`` consecutive epochs. Before returning, the weights from the
    epoch with the lowest validation loss are loaded back into the model.

    Args:
        model (nn.Module): Model to be trained. It is moved to ``device`` and
            updated in place.
        train_dl (DataLoader): DataLoader yielding ``(inputs, labels)`` training
            batches. It must yield at least one batch.
        val_dl (DataLoader): DataLoader yielding ``(inputs, labels)`` validation
            batches. It must yield at least one batch.
        epochs (int): Maximum number of epochs to train.
        patience (int): Number of consecutive epochs without validation
            improvement before early stopping.
        lr (float): Initial Adam learning rate.
        weight_decay (float): Adam weight decay (L2 penalty).
        device (str | torch.device | None): Device to train on. ``None``
            selects CUDA when available, otherwise CPU.

    Returns:
        model (nn.Module): The same model object, loaded with the best weights
            observed (or left as-is if no epoch produced a finite improvement).

    Raises:
        ValueError: If ``train_dl`` or ``val_dl`` yields no batches.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    model.to(device)

    # Both losses below are averaged per batch; fail loudly rather than with
    # a ZeroDivisionError when a loader is empty.
    if len(train_dl) == 0:
        raise ValueError("train_dl must yield at least one batch")
    if len(val_dl) == 0:
        raise ValueError("val_dl must yield at least one batch")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose` is deprecated in recent PyTorch releases; LR reductions are
    # logged explicitly after each scheduler step instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation phase
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_dl:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                val_loss = criterion(outputs, labels)
                running_val_loss += val_loss.item()

        # Check for improvement
        current_val_loss = running_val_loss / len(val_dl)
        print(f"Epoch {epoch+1}, Training Loss: {running_loss / len(train_dl)}")
        print(f"Epoch {epoch+1}, Validation Loss: {current_val_loss}")

        # A NaN loss compares False here and therefore counts as "no improvement".
        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            epochs_no_improve = 0
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps keep mutating; snapshot a real copy
            # so the best weights survive until the end of training.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1

        # Early stopping check
        if epochs_no_improve >= patience:
            print('Early stopping triggered')
            break

        # Step the scheduler and report any learning-rate reduction.
        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(current_val_loss)
        current_lrs = [group['lr'] for group in optimizer.param_groups]
        if current_lrs != previous_lrs:
            print(f"Epoch {epoch+1}, reducing learning rate to {current_lrs}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print('No validation improvement recorded; returning the model as-is')
    return model