# Creating and training models
This notebook creates and trains several model architectures to compare how they perform. Once each model has been trained, it is saved to a folder for future use.

In [1]:
from data_module import get_data_loaders
import os

import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

## Check for GPU

In [2]:
# Pick the fastest available backend: CUDA first, then Apple-Silicon MPS,
# and fall back to the CPU when neither accelerator is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Silicon GPU")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using Apple Silicon GPU


## Loading data loaders
The loaders are stored in a dictionary keyed by split name. Since this notebook focuses on the models, only the training and validation sets are strictly needed, but all three loaders are returned, so the test loader is included as well.

In [3]:
train_loader, val_loader, test_loader = get_data_loaders(batch_size=32)

# Keyed by split name; train_model() looks up the "TRAIN" and "VAL" entries.
dataloaders = dict(TRAIN=train_loader, VAL=val_loader, TEST=test_loader)

Loaded 5216 images from chest_xray/train
Loaded 782 images from chest_xray/val_new
Loaded 624 images from chest_xray/test

✅ Data loaders created!
   Train: 5216 images
   Val:   782 images
   Test:  624 images


## Define a training function
Adapted from a tutorial, this function trains and validates the model at each epoch. For each training run, the model, optimizer, criterion, number of epochs, and early-stopping criteria can be set.

At the end of each training and validation phase, the loss and accuracy are printed to the console. After validation, the function checks whether the early-stopping criteria have been met and whether the current model is the best one so far; if it is, its weights are saved so they can be returned.

In [35]:
def train_model(model, optimizer, criterion, num_epochs=50, early_stopping=None):
    model.train()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        for phase in ["TRAIN", "VAL"]:
            if phase == "TRAIN":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            # running_accuracy = 0.0
            running_correct = 0
            running_total = 0
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "TRAIN"):
                    outputs = model(inputs)  # Shape: (batch_size, 1), dtype: float
                    outputs = outputs.squeeze(1)  # Shape: (batch_size,), dtype: float
                    loss = criterion(outputs, labels.float())  # ✅ Convert labels to float
                    prediction = torch.sigmoid(outputs).round().long()  # For accuracy calculation
                    _, predicted = torch.max(outputs.data, 1)
                
                # Backward pass and optimization
                if phase == "TRAIN":
                    loss.backward()
                    optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                # running_accuracy += torch.sum(prediction == labels)
                running_total += labels
                running_correct += (predicted == labels).sum().item()
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # epoch_acc = running_accuracy.double() / len(dataloaders[phase].dataset)
            epoch_acc = (running_correct/running_total) * 100

            print(f'{phase} Loss: {epoch_loss:.2f} Acc: {epoch_acc.item():.2f}')

            # Early stopping
            if phase == "VAL":
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    model.load_state_dict(best_model_wts)
                    return model

            # Save the best model
            if phase == "VAL" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.2f}')
    model.load_state_dict(best_model_wts)
    return model

## Early stopping
This class tracks the validation loss and signals when training should stop, to prevent the model from overfitting to the training data.

In [6]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A call counts as an improvement when ``-val_loss`` is at least
    ``best_score + delta``. Otherwise a counter is incremented, and
    ``early_stop`` becomes True once ``patience`` consecutive calls have
    failed to improve. A snapshot of the model weights from the best call is
    kept in ``best_model_state``.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        delta: minimum decrease in validation loss that counts as an
            improvement.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None         # highest -val_loss seen so far
        self.early_stop = False        # set once patience is exhausted
        self.counter = 0               # consecutive non-improving calls
        self.best_model_state = None   # independent copy of the best state_dict

    def __call__(self, val_loss, model):
        """Record this epoch's ``val_loss`` and update ``early_stop``."""
        score = -val_loss

        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            # state_dict() returns tensors that share storage with the live
            # parameters; without a deep copy, later training steps would
            # silently overwrite the "best" snapshot.
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def load_best_model(self, model):
        """Load the best recorded weights into ``model``.

        Raises:
            RuntimeError: if no call has recorded a state yet.
        """
        if self.best_model_state is None:
            raise RuntimeError("EarlyStopping has not recorded a model state yet")
        model.load_state_dict(self.best_model_state)

## Generate the different models
Create instances of the different model architectures.

In [7]:
# Cross-entropy over the two-logit classifier heads defined below ("mean" is the default reduction).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [12]:
# DenseNet-121 loaded with its default pretrained weights.
modelDensenetWts = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)

# Freeze the pretrained backbone so only the new head is trained.
for param in modelDensenetWts.parameters():
    param.requires_grad = False

# Swap in a fresh two-output head; new layers are trainable by default.
num_ftrs = modelDensenetWts.classifier.in_features
modelDensenetWts.classifier = nn.Linear(num_ftrs, 2)  # Binary classification
modelDensenetWts = modelDensenetWts.to(device)

# Optimize the classifier head only.
optimizerDensenetWts = optim.Adam(modelDensenetWts.classifier.parameters(), lr=0.001)

early_stopDensenetWts = EarlyStopping(patience=7, delta=0.01)

In [None]:
trainedDensenetWts = train_model(
    modelDensenetWts,
    optimizerDensenetWts,
    criterion,
    num_epochs=50,
    early_stopping=early_stopDensenetWts,  # the EarlyStopping created for this model
)

# Save only the learned weights; reload with model.load_state_dict(torch.load(path)).
os.makedirs("saved_models", exist_ok=True)
path = os.path.join("saved_models", "densenet121_weights.pth")
torch.save(trainedDensenetWts.state_dict(), path)

Epoch 1/50
----------
