In [68]:
from data_module import get_data_loaders

import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms

In [69]:
train_loader, val_loader, test_loader = get_data_loaders(batch_size=32)
dataloaders = {"TRAIN": train_loader, "VAL": val_loader, "TEST": test_loader}

Loaded 5216 images from chest_xray/train
Loaded 782 images from chest_xray/val_new
Loaded 624 images from chest_xray/test

✅ Data loaders created!
   Train: 5216 images
   Val:   782 images
   Test:  624 images


## Define a training function
Focusing on the training loop

In [92]:
def train_model(model, optimizer, criterion, num_epochs=50, early_stopping=None):
    model.train()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        for phase in ["TRAIN", "VAL"]:
            if phase == "TRAIN":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_accuracy = 0.0
            
            for inputs, labels in dataloaders[phase]:
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "TRAIN"):
                    outputs = model(inputs)  # Shape: (batch_size, 1), dtype: float
                    outputs = outputs.squeeze(1)  # Shape: (batch_size,), dtype: float
                    loss = criterion(outputs, labels.float())  # ✅ Convert labels to float
                    prediction = torch.sigmoid(outputs).round().long()  # For accuracy calculation
                
                # Backward pass and optimization
                if phase == "TRAIN":
                    loss.backward()
                    optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_accuracy += torch.sum(prediction == labels)
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_accuracy.double() / len(dataloaders[phase].dataset)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Early stopping
            if phase == "VAL":
                early_stopping(epoch_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    model.load_state_dict(best_model_wts)
                    return model

            # Save the best model
            if phase == "VAL" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# Early stopping

In [72]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.best_model_state = model.state_dict()
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = model.state_dict()
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)

## CNN 

In [73]:
def cnn_model():
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(64 * 56 * 56, 128),
        nn.ReLU(),
        nn.Linear(128, 1)  # Binary classification
    )
    return model 

## Generate the different models
Create instances of the different model architectures

### ResNet18

In [74]:
criterion = nn.CrossEntropyLoss()

In [None]:
modelResNet18Wts = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = modelResNet18Wts.fc.in_features

for param in modelResNet18Wts.parameters():
    param.requires_grad = False

modelResNet18Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18Wts = optim.SGD(modelResNet18Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)


In [None]:
early_stopResNet18Wts = EarlyStopping(patience=7, delta=0.01)
trainedResNet18Wts = train_model(modelResNet18Wts, optimizerResNet18Wts, criterion, num_epochs=30, early_stopping=early_stop)

Epoch 1/30
----------
TRAIN Loss: 90.0724 Acc: 0.5121
VAL Loss: 127.3932 Acc: 0.2570
Epoch 2/30
----------
TRAIN Loss: 213.8531 Acc: 0.3842
VAL Loss: 384.3915 Acc: 0.2621
Epoch 3/30
----------
TRAIN Loss: 316.6101 Acc: 0.4034
VAL Loss: 312.5717 Acc: 0.3325
Epoch 4/30
----------
TRAIN Loss: 355.9608 Acc: 0.3175
VAL Loss: 364.6425 Acc: 0.5831
Epoch 5/30
----------
TRAIN Loss: 377.0325 Acc: 0.3717
VAL Loss: 435.0505 Acc: 0.7263
Epoch 6/30
----------
TRAIN Loss: 343.5242 Acc: 0.3825
VAL Loss: 363.7200 Acc: 0.6151
Epoch 7/30
----------
TRAIN Loss: 340.4999 Acc: 0.3535
VAL Loss: 361.7333 Acc: 0.2570
Epoch 8/30
----------
TRAIN Loss: 345.9574 Acc: 0.4101
VAL Loss: 416.8527 Acc: 0.2621
Early stopping


### ResNet 18 without default weights

In [None]:
modelResNet18 = models.resnet18()
num_ftrs = modelResNet18.fc.in_features

for param in modelResNet18.parameters():
    param.requires_grad = False

modelResNet18.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet18 = optim.SGD(modelResNet18.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)


In [None]:
early_stopResNet18 = EarlyStopping(patience=7, delta=0.01)
trainedResNet18 = train_model(modelResNet18, optimizerResNet18, criterion, num_epochs=30, early_stopping=early_stopResNet18)

Epoch 1/30
----------
TRAIN Loss: 80.9416 Acc: 0.7354
VAL Loss: 82.3095 Acc: 0.7430
Epoch 2/30
----------
TRAIN Loss: 80.9566 Acc: 0.7465
VAL Loss: 81.8466 Acc: 0.7430
Epoch 3/30
----------
TRAIN Loss: 81.1168 Acc: 0.7538
VAL Loss: 81.8504 Acc: 0.7430
Epoch 4/30
----------
TRAIN Loss: 81.8751 Acc: 0.7521
VAL Loss: 82.3054 Acc: 0.7442
Epoch 5/30
----------
TRAIN Loss: 80.6045 Acc: 0.7465
VAL Loss: 82.7587 Acc: 0.7468
Epoch 6/30
----------
TRAIN Loss: 80.9304 Acc: 0.7454
VAL Loss: 81.8766 Acc: 0.7430
Epoch 7/30
----------
TRAIN Loss: 80.3257 Acc: 0.7441
VAL Loss: 82.6628 Acc: 0.7442
Epoch 8/30
----------
TRAIN Loss: 80.9267 Acc: 0.7471
VAL Loss: 81.9864 Acc: 0.7430
Epoch 9/30
----------
TRAIN Loss: 80.6828 Acc: 0.7481
VAL Loss: 82.4114 Acc: 0.7430
Early stopping


### ResNet50

In [88]:
modelResNet50Wts = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_ftrs = modelResNet50Wts.fc.in_features

for param in modelResNet50Wts.parameters():
    param.requires_grad = False

modelResNet50Wts.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50Wts = optim.SGD(modelResNet50Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/vky/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100.0%


In [None]:
early_stopResNet50Wts = EarlyStopping(patience=7, delta=0.01)
trainedResNet50Wts = train_model(modelResNet50Wts, optimizerResNet50Wts, criterion, num_epochs=30, early_stopping=early_stopResNet50Wts)

Epoch 1/30
----------


In [None]:
modelResNet50 = models.resnet50()
num_ftrs = modelResNet50.fc.in_features

for param in modelResNet50.parameters():
    param.requires_grad = False

modelResNet50.fc = nn.Linear(num_ftrs, 1)  # Binary classification
optimizerResNet50 = optim.SGD(modelResNet50.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)


In [None]:
early_stopResNet50 = EarlyStopping(patience=7, delta=0.01)
trainedResNet50 = train_model(modelResNet50, optimizerResNet50, criterion, num_epochs=30, early_stopping=early_stopResNet50)

### VGG16

In [82]:
modelVGG16Wts = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

for param in modelVGG16Wts.parameters():
    param.requires_grad = False

num_features = modelVGG16Wts.classifier[6].in_features
features = list(modelVGG16Wts.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16Wts.classifier = nn.Sequential(*features)

In [83]:
optimizerVGG16Wts = optim.SGD(modelVGG16Wts.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

In [85]:
early_stopVGG16Wts = EarlyStopping(patience=7, delta=0.01)
trainedVGG16Wts = train_model(modelVGG16Wts, optimizerVGG16Wts, criterion, num_epochs=30, early_stopping=early_stopVGG16Wts)

Epoch 1/30
----------
TRAIN Loss: 3704.5884 Acc: 0.3576
VAL Loss: 2298.7189 Acc: 0.2570
Epoch 2/30
----------
TRAIN Loss: 4965.0384 Acc: 0.3466
VAL Loss: 2368.1020 Acc: 0.2570
Epoch 3/30
----------
TRAIN Loss: 5497.6240 Acc: 0.3535
VAL Loss: 2719.3550 Acc: 0.2685
Epoch 4/30
----------
TRAIN Loss: 5353.7078 Acc: 0.3725
VAL Loss: 2267.3925 Acc: 0.6100
Epoch 5/30
----------
TRAIN Loss: 5705.9892 Acc: 0.3357
VAL Loss: 3543.4712 Acc: 0.2852
Epoch 6/30
----------
TRAIN Loss: 5757.1336 Acc: 0.3689
VAL Loss: 3090.7009 Acc: 0.2570
Epoch 7/30
----------
TRAIN Loss: 5800.6663 Acc: 0.2899
VAL Loss: 2749.9224 Acc: 0.2570
Epoch 8/30
----------
TRAIN Loss: 5840.5577 Acc: 0.3043
VAL Loss: 2594.6437 Acc: 0.2570
Epoch 9/30
----------
TRAIN Loss: 5738.3455 Acc: 0.3367
VAL Loss: 3240.1246 Acc: 0.4949
Epoch 10/30
----------
TRAIN Loss: 5785.1744 Acc: 0.3503
VAL Loss: 3096.3356 Acc: 0.2570
Epoch 11/30
----------
TRAIN Loss: 5829.0300 Acc: 0.3583
VAL Loss: 3018.1139 Acc: 0.2711
Early stopping


In [86]:
modelVGG16 = models.vgg16()

for param in modelVGG16.parameters():
    param.requires_grad = False

num_features = modelVGG16.classifier[6].in_features
features = list(modelVGG16.classifier.children())[:-1]
features.extend([nn.Linear(num_features, 1)])  # Binary classification
modelVGG16.classifier = nn.Sequential(*features)

optimizerVGG16 = optim.SGD(modelVGG16.parameters(), lr=0.001, momentum=0.9, weight_decay=0.001)

In [87]:
early_stopVGG16 = EarlyStopping(patience=7, delta=0.01)
trainedVGG16 = train_model(modelVGG16, optimizerVGG16, criterion, num_epochs=30, early_stopping=early_stopVGG16)

Epoch 1/30
----------
TRAIN Loss: 82.0456 Acc: 0.2763
VAL Loss: 81.1575 Acc: 0.2634
Epoch 2/30
----------
TRAIN Loss: 82.0039 Acc: 0.2876
VAL Loss: 81.1652 Acc: 0.3350
Epoch 3/30
----------
TRAIN Loss: 81.9872 Acc: 0.3140
VAL Loss: 81.2008 Acc: 0.2647
Epoch 4/30
----------
TRAIN Loss: 82.0858 Acc: 0.3282
VAL Loss: 81.1779 Acc: 0.2941
Epoch 5/30
----------
TRAIN Loss: 82.1179 Acc: 0.3240
VAL Loss: 81.1939 Acc: 0.3043
Epoch 6/30
----------
TRAIN Loss: 81.9140 Acc: 0.3321
VAL Loss: 81.2384 Acc: 0.2570
Epoch 7/30
----------
TRAIN Loss: 81.9239 Acc: 0.3056
VAL Loss: 81.1859 Acc: 0.2621
Epoch 8/30
----------
TRAIN Loss: 81.9642 Acc: 0.3338
VAL Loss: 81.2024 Acc: 0.2762
Early stopping


## Optimizers and loss functions
Potentially different for each model. 