# Creating and training models
This file/notebook will create and train different models to see how different models will act.  Once the model has been trained, it will save the models to a folder for future use.

In [1]:
from data_module import get_data_loaders
import os

import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

## Check for GPU

In [2]:
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Silicon GPU")
else:
    print("Using CPU")

Using Apple Silicon GPU


In [3]:
device = torch.device("cpu")
print(device)

cpu


## Loading data loaders
Putting the loaders in a dictionary.  Since this file focuses on the models, I really only need the training and validation sets, but since all three loaders are given, it wouldn't hurt to add it to the dictionary.

In [4]:
train_loader, val_loader, test_loader = get_data_loaders(batch_size=32)
dataloaders = {"TRAIN": train_loader, "VAL": val_loader, "TEST": test_loader}

Loaded 5216 images from chest_xray/train
Loaded 782 images from chest_xray/val_new
Loaded 624 images from chest_xray/test

✅ Data loaders created!
   Train: 5216 images
   Val:   782 images
   Test:  624 images


## Define a training function
Followed a tutorial that trains and validates the model at each epoch. With each training, the model, optimizer, criterion, number of epochs and the early stopping critera can be set. 

At the end of each epoch, the statistics is printed to the console, before checking to see if the stopping critera has been met, and if the current model the is being tested is the best model.  If it is, then it is saved to be returned.

In [5]:
def train_model(model, optimizer, criterion, num_epochs=50, early_stopping=None):
    model.train()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        for phase in ["TRAIN", "VAL"]:
            if phase == "TRAIN":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_accuracy = 0.0
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "TRAIN"):
                    outputs = model(inputs)  # Shape: (batch_size, 1), dtype: float
                    outputs = outputs.squeeze(1)  # Shape: (batch_size,), dtype: float
                    loss = criterion(outputs, labels.float())  # ✅ Convert labels to float
                    prediction = torch.sigmoid(outputs).round().long()  # For accuracy calculation
                
                # Backward pass and optimization
                if phase == "TRAIN":
                    loss.backward()
                    optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_accuracy += torch.sum(prediction == labels)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_accuracy.double() / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Early stopping
            if phase == "VAL":
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    model.load_state_dict(best_model_wts)
                    return model

            # Save the best model
            if phase == "VAL" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

## Early stopping
This method checks to make sure the training stops in time to prevent overfitting to the training data.

In [6]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

## Generate the different models
Create instances of the different model architectures

In [7]:
criterion = nn.CrossEntropyLoss()

## CNN 

In [8]:
def cnn_model():
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(64 * 56 * 56, 128),
        nn.ReLU(),
        nn.Linear(128, 1)  # Binary classification
    )
    return model 

In [None]:
modelCNN = cnn_model().to(device)
optimizerCNN = optim.SGD(cnn_model().parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)
early_stopCNN = EarlyStopping(patience=7, delta=0.01)


trainedCNN = train_model(modelCNN, optimizerCNN, criterion, num_epochs=30, early_stopping=early_stopCNN)
pathcnn = os.path.join("saved_models", "cnn_model.pth")
torch.save(trainedCNN.state_dict(), pathcnn)

Epoch 1/30
----------
TRAIN Loss: 82.3640 Acc: 0.7421
VAL Loss: 81.0791 Acc: 0.7430
Epoch 2/30
----------
TRAIN Loss: 82.3645 Acc: 0.7427
VAL Loss: 81.0791 Acc: 0.7430
Epoch 3/30
----------
TRAIN Loss: 82.3666 Acc: 0.7408
VAL Loss: 81.0791 Acc: 0.7430
Epoch 4/30
----------
TRAIN Loss: 82.3684 Acc: 0.7410
VAL Loss: 81.0791 Acc: 0.7430
Epoch 5/30
----------
TRAIN Loss: 82.3636 Acc: 0.7425
VAL Loss: 81.0791 Acc: 0.7430
Epoch 6/30
----------
TRAIN Loss: 82.3746 Acc: 0.7433
VAL Loss: 81.0791 Acc: 0.7430
Epoch 7/30
----------


### ResNet18

In [None]:
modelResNet18Wts = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = modelResNet18Wts.fc.in_features

for param in modelResNet18Wts.parameters():
    param.requires_grad = False

modelResNet18Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18Wts = optim.SGD(modelResNet18Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet18Wts = EarlyStopping(patience=7, delta=0.01)
trainedResNet18Wts = train_model(modelResNet18Wts, optimizerResNet18Wts, criterion, num_epochs=30, early_stopping=early_stop)

In [100]:
path = os.path.join("saved_models", "resnet18_weights.pth")
torch.save(trainedResNet18Wts.state_dict(), path)

### ResNet 18 without default weights

In [None]:
modelResNet18 = models.resnet18()
num_ftrs = modelResNet18.fc.in_features

for param in modelResNet18.parameters():
    param.requires_grad = False

modelResNet18.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18 = optim.SGD(modelResNet18.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet18 = EarlyStopping(patience=7, delta=0.01)
trainedResNet18 = train_model(modelResNet18, optimizerResNet18, criterion, num_epochs=30, early_stopping=early_stopResNet18)

In [101]:
pathResNet18 = os.path.join("saved_models", "resnet18.pth")
torch.save(trainedResNet18.state_dict(), pathResNet18)

### ResNet50

In [88]:
modelResNet50Wts = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = modelResNet50Wts.fc.in_features

for param in modelResNet50Wts.parameters():
    param.requires_grad = False

modelResNet50Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50Wts = optim.SGD(modelResNet50Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet50Wts = EarlyStopping(patience=7, delta=0.01)
trainedResNet50Wts = train_model(modelResNet50Wts, optimizerResNet50Wts, criterion, num_epochs=30, early_stopping=early_stopResNet50Wts)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/vky/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100.0%


In [102]:
pathResNet50Wts = os.path.join("saved_models", "resnet50_weights.pth")
torch.save(trainedResNet50Wts.state_dict(), pathResNet50Wts)

In [None]:
modelResNet50 = models.resnet50()
num_ftrs = modelResNet50.fc.in_features

for param in modelResNet50.parameters():
    param.requires_grad = False

modelResNet50.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50 = optim.SGD(modelResNet50.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet50 = EarlyStopping(patience=7, delta=0.01)
trainedResNet50 = train_model(modelResNet50, optimizerResNet50, criterion, num_epochs=30, early_stopping=early_stopResNet50)

In [103]:
pathresnet50 = os.path.join("saved_models", "resnet50.pth")
torch.save(trainedResNet50.state_dict(), pathresnet50)

### VGG16

In [None]:
modelVGG16Wts = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

for param in modelVGG16Wts.parameters():
    param.requires_grad = False

num_features = modelVGG16Wts.classifier[6].in_features
features = list(modelVGG16Wts.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16Wts.classifier = nn.Sequential(*features)

optimizerVGG16Wts = optim.SGD(modelVGG16Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopVGG16Wts = EarlyStopping(patience=7, delta=0.01)
trainedVGG16Wts = train_model(modelVGG16Wts, optimizerVGG16Wts, criterion, num_epochs=30, early_stopping=early_stopVGG16Wts)

In [104]:
pathVGG16Wts = os.path.join("saved_models", "vgg16_weights.pth")
torch.save(trainedVGG16Wts.state_dict(), pathVGG16Wts)

In [None]:
modelVGG16 = models.vgg16()

for param in modelVGG16.parameters():
    param.requires_grad = False

num_features = modelVGG16.classifier[6].in_features
features = list(modelVGG16.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16.classifier = nn.Sequential(*features)

optimizerVGG16 = optim.SGD(modelVGG16.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopVGG16 = EarlyStopping(patience=7, delta=0.01)
trainedVGG16 = train_model(modelVGG16, optimizerVGG16, criterion, num_epochs=30, early_stopping=early_stopVGG16)

In [105]:
pathVGG16Wts = os.path.join("saved_models", "vgg16_weights.pth")
torch.save(trainedVGG16Wts.state_dict(), pathVGG16Wts)

## Optimizers and loss functions
Potentially different for each model. 