# Creating and training models
This file/notebook will create and train different models to see how different models will act.  Once the model has been trained, it will save the models to a folder for future use.

In [1]:
from data_module import get_data_loaders
import os

import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

## Check for GPU

In [2]:
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Silicon GPU")
else:
    print("Using CPU")

Using Apple Silicon GPU


## Loading data loaders
Putting the loaders in a dictionary.  Since this file focuses on the models, I really only need the training and validation sets, but since all three loaders are given, it wouldn't hurt to add it to the dictionary.

In [3]:
train_loader, val_loader, test_loader = get_data_loaders(batch_size=32)
dataloaders = {"TRAIN": train_loader, "VAL": val_loader, "TEST": test_loader}

Loaded 5216 images from chest_xray/train
Loaded 782 images from chest_xray/val_new
Loaded 624 images from chest_xray/test

✅ Data loaders created!
   Train: 5216 images
   Val:   782 images
   Test:  624 images


## Define a training function
Followed a tutorial that trains and validates the model at each epoch. With each training, the model, optimizer, criterion, number of epochs and the early stopping critera can be set. 

At the end of each epoch, the statistics is printed to the console, before checking to see if the stopping critera has been met, and if the current model the is being tested is the best model.  If it is, then it is saved to be returned.

In [42]:
def train_model(model, optimizer, criterion, num_epochs=50, early_stopping=None):
    model.train()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('-' * 10)
        print(f'Epoch {epoch+1}/{num_epochs}')
        
        
        for phase in ["TRAIN", "VAL"]:
            if phase == "TRAIN":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            # running_accuracy = 0.0
            running_correct = 0
            running_total = 0
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "TRAIN"):
                    outputs = model(inputs)  # Shape: (batch_size, 1), dtype: float
                    outputs = outputs.squeeze(1)  # Shape: (batch_size,), dtype: float
                    loss = criterion(outputs, labels.float())  # ✅ Convert labels to float
                    prediction = torch.sigmoid(outputs).round().long()  # For accuracy calculation
                    _, predicted = torch.max(outputs.data, 1)
                
                # Backward pass and optimization
                if phase == "TRAIN":
                    loss.backward()
                    optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                # running_accuracy += torch.sum(prediction == labels)
                running_total += labels.size(0)
                running_correct += (predicted == labels).sum().item()
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # epoch_acc = running_accuracy.double() / len(dataloaders[phase].dataset)
            epoch_acc = (running_correct/running_total) * 100

            print(f'{phase} Loss: {epoch_loss:.2f} Acc: {epoch_acc:.2f}')

            # Early stopping
            if phase == "VAL":
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    model.load_state_dict(best_model_wts)
                    return model

            # Save the best model
            if phase == "VAL" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.2f}')
    model.load_state_dict(best_model_wts)
    return model

## Early stopping
This method checks to make sure the training stops in time to prevent overfitting to the training data.

In [6]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

## Generate the different models
Create instances of the different model architectures

In [7]:
criterion = nn.CrossEntropyLoss()

### CNN

In [None]:
def cnn_model():
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(64 * 56 * 56, 128),
        nn.ReLU(),
        nn.Linear(128, 1)  # Binary classification
    )
    return model 

In [None]:
modelCNN = cnn_model().to(device)
optimizerCNN = optim.SGD(cnn_model().parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopCNN = EarlyStopping(patience=7, delta=0.01)
modelCNN = modelCNN.to(device)

trainedCNN = train_model(modelCNN, optimizerCNN, criterion, num_epochs=30, early_stopping=early_stopCNN)
pathcnn = os.path.join("saved_models", "cnn_model.pth")
torch.save(trainedCNN.state_dict(), pathcnn)

### ResNet18

In [None]:
modelResNet18Wts = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = modelResNet18Wts.fc.in_features

for param in modelResNet18Wts.parameters():
    param.requires_grad = False

modelResNet18Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18Wts = optim.SGD(modelResNet18Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet18Wts = EarlyStopping(patience=7, delta=0.01)
modelResNet18Wts = modelResNet18Wts.to(device)

trainedResNet18Wts = train_model(modelResNet18Wts, optimizerResNet18Wts, criterion, num_epochs=30, early_stopping=early_stop)

path = os.path.join("saved_models", "resnet18_weights.pth")
torch.save(trainedResNet18Wts.state_dict(), path)

In [None]:
modelResNet18 = models.resnet18()
num_ftrs = modelResNet18.fc.in_features

for param in modelResNet18.parameters():
    param.requires_grad = False

modelResNet18.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18 = optim.SGD(modelResNet18.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet18 = EarlyStopping(patience=7, delta=0.01)
modelResNet18 = modelResNet18.to(device)

trainedResNet18 = train_model(modelResNet18, optimizerResNet18, criterion, num_epochs=30, early_stopping=early_stopResNet18)

pathResNet18 = os.path.join("saved_models", "resnet18.pth")
torch.save(trainedResNet18.state_dict(), pathResNet18)

### ResNet50

In [None]:
modelResNet50Wts = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = modelResNet50Wts.fc.in_features

for param in modelResNet50Wts.parameters():
    param.requires_grad = False

modelResNet50Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50Wts = optim.SGD(modelResNet50Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet50Wts = EarlyStopping(patience=7, delta=0.01)
modelResNet50Wts = modelResNet50Wts.to(device)

trainedResNet50Wts = train_model(modelResNet50Wts, optimizerResNet50Wts, criterion, num_epochs=30, early_stopping=early_stopResNet50Wts)

pathResNet50Wts = os.path.join("saved_models", "resnet50_weights.pth")
torch.save(trainedResNet50Wts.state_dict(), pathResNet50Wts)

In [None]:
modelResNet50 = models.resnet50()
num_ftrs = modelResNet50.fc.in_features

for param in modelResNet50.parameters():
    param.requires_grad = False

modelResNet50.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50 = optim.SGD(modelResNet50.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopResNet50 = EarlyStopping(patience=7, delta=0.01)
modelResNet50 = modelResNet50.to(device)

trainedResNet50 = train_model(modelResNet50, optimizerResNet50, criterion, num_epochs=30, early_stopping=early_stopResNet50)

pathresnet50 = os.path.join("saved_models", "resnet50.pth")
torch.save(trainedResNet50.state_dict(), pathresnet50)

### VGG16

In [None]:
modelVGG16Wts = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

for param in modelVGG16Wts.parameters():
    param.requires_grad = False

num_features = modelVGG16Wts.classifier[6].in_features
features = list(modelVGG16Wts.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16Wts.classifier = nn.Sequential(*features)

optimizerVGG16Wts = optim.SGD(modelVGG16Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopVGG16Wts = EarlyStopping(patience=7, delta=0.01)
modelVGG16Wts = modelVGG16Wts.to(device)

trainedVGG16Wts = train_model(modelVGG16Wts, optimizerVGG16Wts, criterion, num_epochs=30, early_stopping=early_stopVGG16Wts)

pathVGG16Wts = os.path.join("saved_models", "vgg16_weights.pth")
torch.save(trainedVGG16Wts.state_dict(), pathVGG16Wts)

In [None]:
modelVGG16 = models.vgg16()

for param in modelVGG16.parameters():
    param.requires_grad = False

num_features = modelVGG16.classifier[6].in_features
features = list(modelVGG16.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16.classifier = nn.Sequential(*features)

optimizerVGG16 = optim.SGD(modelVGG16.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

early_stopVGG16 = EarlyStopping(patience=7, delta=0.01)
modelVGG16 = modelVGG16.to(device)

trainedVGG16 = train_model(modelVGG16, optimizerVGG16, criterion, num_epochs=30, early_stopping=early_stopVGG16)

pathVGG16 = os.path.join("saved_models", "vgg16.pth")
torch.save(trainedVGG16.state_dict(), pathVGG16)

### DenseNet

In [41]:
modelDensenetWts = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
num_ftrs = modelDensenetWts.classifier.in_features

for param in modelDensenetWts.parameters():
    param.requires_grad = False

modelDensenetWts.classifier = nn.Linear(num_ftrs, 2)  # Binary classification
optimizerDensenetWts = optim.Adam(modelDensenetWts.classifier.parameters(), lr=0.001)

early_stopDensenetWts = EarlyStopping(patience=7, delta=0.01)
modelDensenetWts = modelDensenetWts.to(device)

In [43]:
trainedDensenetWts = train_model(modelDensenetWts, optimizerDensenetWts, criterion, num_epochs=50, early_stopping=early_stopDensenetWts)

pathDensenetWts = os.path.join("saved_models", "densenet121_weights.pth")
torch.save(trainedDensenetWts.state_dict(), pathDensenetWts)

----------
Epoch 1/50
TRAIN Loss: 0.02 Acc: 92.39
VAL Loss: 0.00 Acc: 97.31
----------
Epoch 2/50
TRAIN Loss: 0.01 Acc: 91.70
VAL Loss: 0.00 Acc: 91.82
----------
Epoch 3/50
TRAIN Loss: 0.01 Acc: 90.45
VAL Loss: 0.00 Acc: 91.43
----------
Epoch 4/50
TRAIN Loss: 0.01 Acc: 93.50
VAL Loss: 0.00 Acc: 91.43
----------
Epoch 5/50
TRAIN Loss: 0.01 Acc: 91.66
VAL Loss: 0.01 Acc: 91.94
----------
Epoch 6/50
TRAIN Loss: 0.02 Acc: 87.98
VAL Loss: 0.01 Acc: 83.12
----------
Epoch 7/50
TRAIN Loss: 0.02 Acc: 90.09
VAL Loss: 0.00 Acc: 95.40
----------
Epoch 8/50
TRAIN Loss: 0.02 Acc: 89.38
VAL Loss: 0.00 Acc: 93.35
Early stopping


In [None]:
modelDensenet = models.densenet121()
num_ftrs = modelDensenet.classifier.in_features

for param in modelDensenet.parameters():
    param.requires_grad = False

modelDensenet.classifier = nn.Linear(num_ftrs, 2)  # Binary classification
optimizerDensenet = optim.Adam(modelDensenet.classifier.parameters(), lr=0.001)

early_stopDensenet = EarlyStopping(patience=7, delta=0.01)
modelDensenet = modelDensenet.to(device)
trainedDensenet = train_model(modelDensenet, optimizerDensenet, criterion, num_epochs=50, early_stopping=early_stopDensenet)

pathDensenet = os.path.join("saved_models", "densenet121.pth")
torch.save(trainedDensenet.state_dict(), pathDensenet)

### Efficient

In [47]:
modelEfficientWts = models.efficientnet_b0(weights= models.EfficientNet_B0_Weights.DEFAULT)
num_ftrs = modelEfficientWts.classifier[1].in_features

for param in modelEfficientWts.parameters():
    param.requires_grad = False

modelEfficientWts.classifier[1] = nn.Linear(num_ftrs, 2)  # Binary classification
optimizerEfficientWts = optim.Adam(modelEfficientWts.classifier.parameters(), lr=0.001)

early_stopEfficientWts = EarlyStopping(patience=7, delta=0.01)
modelEfficientWts = modelEfficientWts.to(device)

trainedEfficientWts = train_model(modelEfficientWts, optimizerEfficientWts, criterion, num_epochs=50, early_stopping=early_stopEfficientWts)

pathEfficientWts = os.path.join("saved_models", "efficientnet_b0_weights.pth")
torch.save(trainedEfficientWts.state_dict(), pathEfficientWts)

4.3%

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /Users/vky/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100.0%


----------
Epoch 1/50
TRAIN Loss: 0.05 Acc: 80.87
VAL Loss: 0.04 Acc: 77.24
----------
Epoch 2/50
TRAIN Loss: 0.04 Acc: 83.84
VAL Loss: 0.03 Acc: 85.29
----------
Epoch 3/50
TRAIN Loss: 0.03 Acc: 83.34
VAL Loss: 0.03 Acc: 86.45
----------
Epoch 4/50
TRAIN Loss: 0.03 Acc: 90.53
VAL Loss: 0.01 Acc: 92.46
----------
Epoch 5/50
TRAIN Loss: 0.04 Acc: 85.39
VAL Loss: 0.01 Acc: 96.29
----------
Epoch 6/50
TRAIN Loss: 0.02 Acc: 91.18
VAL Loss: 0.03 Acc: 88.75
----------
Epoch 7/50
TRAIN Loss: 0.03 Acc: 87.90
VAL Loss: 0.03 Acc: 77.75
----------
Epoch 8/50
TRAIN Loss: 0.04 Acc: 82.30
VAL Loss: 0.03 Acc: 86.83
----------
Epoch 9/50
TRAIN Loss: 0.02 Acc: 83.84
VAL Loss: 0.05 Acc: 65.09
----------
Epoch 10/50
TRAIN Loss: 0.04 Acc: 80.10
VAL Loss: 0.04 Acc: 81.59
----------
Epoch 11/50
TRAIN Loss: 0.03 Acc: 81.42
VAL Loss: 0.00 Acc: 93.48
Early stopping


In [None]:
modelEfficient = models.efficientnet_b0()
num_ftrs = modelEfficient.classifier[1].in_features

for param in modelEfficient.parameters():
    param.requires_grad = False

modelEfficient.classifier[1] = nn.Linear(num_ftrs, 2)  # Binary classification
optimizerEfficient = optim.Adam(modelEfficient.classifier.parameters(), lr=0.001)

early_stopEfficient = EarlyStopping(patience=7, delta=0.01)
modelEfficient = modelEfficient.to(device)

trainedEfficient = train_model(modelEfficient, optimizerEfficient, criterion, num_epochs=50, early_stopping=early_stopEfficient)

pathEfficient = os.path.join("saved_models", "efficientnet_b0.pth")
torch.save(trainedEfficient.state_dict(), pathEfficient)

----------
Epoch 1/50
TRAIN Loss: 0.07 Acc: 85.51
VAL Loss: 0.03 Acc: 77.75
----------
Epoch 2/50
