# Section 04 — Pretrained CNN Models (Transfer Learning)

This section fine-tunes the following pretrained architectures:
- ResNet50
- DenseNet121
- VGG16

Since the input scans are grayscale, the first convolution layer of each network is replaced with a 1-channel convolution whose weights are initialized from the pretrained RGB filters.

We freeze all backbone layers except the final block (ResNet50 `layer4`, DenseNet121 `denseblock4`, VGG16 convolutional block 5) to reduce training cost and the risk of overfitting.  
The final classification layer is replaced with a new fully connected layer with four outputs, one per class.

Training uses the Adam optimizer with two learning rates:
- A higher learning rate (1e-3) for the classifier head
- A lower learning rate (1e-4) for the unfrozen backbone layers

All models are trained for 15 epochs; after each epoch the weights are checkpointed whenever validation accuracy improves on the best value so far.

In [None]:
# import necessary libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
IMAGE_SIZE = 224
BATCH_SIZE = 32
SEED = 42
DATA_DIR = "data"

# Seed torch's global RNG so shuffling, random augmentation and the randomly
# initialized new layers created later are repeatable across Restart & Run All
# (GPU kernels may still introduce small run-to-run differences).
torch.manual_seed(SEED)

# Training transforms (with augmentation)
train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.10, contrast=0.10),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Validation transforms (no augmentation)
val_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transforms)
val_dataset   = datasets.ImageFolder(f"{DATA_DIR}/validation", transform=val_transforms)

class_names = train_dataset.classes
num_classes = len(class_names)

# ImageFolder assigns label indices from the sorted folder names of each split
# independently; a missing or misspelled class folder would silently shift labels.
assert val_dataset.classes == class_names, (
    f"class mismatch: train={class_names} validation={val_dataset.classes}"
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("✔ Loaded train + validation datasets.")
print("Classes:", class_names)

✔ Loaded train + validation datasets.
Classes: ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']


In [None]:
# Training function for pretrained models with partial fine-tuning
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return the per-sample mean loss.

    Note: if the model contains BatchNorm layers, ``model.train()`` lets them keep
    updating their running statistics even when their weights are frozen.
    """
    model.train()
    loss_sum, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so a short
        # final batch does not skew the epoch average.
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)
    return loss_sum / seen if seen else float("nan")


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    return correct / total


def train_pretrained_model(
    model,
    train_loader,
    val_loader,
    num_epochs=15,
    lr_classifier=0.001,
    lr_backbone=0.0001,
    model_name="Model",
    save_path="model_best.pth"
):
    """Fine-tune a partially frozen model and return it with its best-epoch weights.

    The model must carry two attributes set by its builder:
    ``classifier_params`` (the new head) and ``backbone_params`` (the unfrozen
    backbone block). Each becomes its own Adam parameter group with its own
    learning rate. After every epoch the model is scored on ``val_loader``.
    Whenever validation accuracy improves, ``model.state_dict()`` is written to
    ``save_path``.

    Args:
        model: ``nn.Module`` exposing ``classifier_params`` / ``backbone_params``
            (iterables of parameters). They are materialized into lists and
            written back, so the function can be called again on the same model.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        lr_classifier: Adam learning rate for the classifier group.
        lr_backbone: Adam learning rate for the backbone group.
        model_name: Label used in progress messages.
        save_path: Checkpoint file for the best epoch. Missing parent
            directories are created.

    Returns:
        The same ``model`` object, moved to the training device. If at least one
        epoch ran, the weights of the best validation epoch are reloaded, so the
        returned model matches the checkpoint on disk.

    Raises:
        ValueError: if the classifier parameter group is empty (e.g. it was a
            generator that had already been consumed), or if ``val_loader``
            yields no samples.
    """
    import os  # local import: only needed to create the checkpoint directory

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    # Builders may hand over one-shot ``parameters()`` generators. Materialize
    # and store them so a re-run does not end up with silently empty groups.
    classifier_params = list(model.classifier_params)
    backbone_params = list(model.backbone_params)
    if not classifier_params:
        raise ValueError(
            f"{model_name}: classifier_params is empty; if it was a parameters() "
            "generator it has already been consumed - rebuild the model"
        )
    model.classifier_params = classifier_params
    model.backbone_params = backbone_params

    # Two learning rates: a fast-moving new head, a slowly adapting backbone block.
    param_groups = [{"params": classifier_params, "lr": lr_classifier}]
    if backbone_params:
        param_groups.append({"params": backbone_params, "lr": lr_backbone})
    optimizer = torch.optim.Adam(param_groups)

    checkpoint_dir = os.path.dirname(save_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Sentinel below any real accuracy, so the first epoch always checkpoints.
    best_val_accuracy = -1.0

    print(f"\n===== Training {model_name} (Partial Fine-Tuning) =====")

    for epoch in range(1, num_epochs + 1):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_acc = _evaluate_accuracy(model, val_loader, device)

        print(f"Epoch {epoch}/{num_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}")

        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"✔ Best model saved → {save_path}")

    print(f"\nTraining completed for {model_name}")

    if best_val_accuracy < 0:
        print("No epochs were run; model weights are unchanged.")
        return model

    # Return the best checkpoint rather than the last epoch, so downstream
    # evaluation uses the same weights that were saved to disk.
    model.load_state_dict(torch.load(save_path, map_location=device))
    print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")

    return model

In [None]:
# Building pretrained models with grayscale fix and partial fine-tuning
def build_resnet50(num_classes, weights="IMAGENET1K_V1"):
    """Build a torchvision ResNet50 adapted to 1-channel input for partial fine-tuning.

    - ``conv1`` is replaced by a 1-input-channel conv with the same geometry.
      Its filters are the pretrained RGB filters summed over the channel axis,
      so a 1-channel tensor produces exactly the first-layer output the original
      conv gives for that tensor repeated across three channels.
    - All parameters are frozen except ``layer4`` (the last residual stage) and
      the freshly created ``fc`` head.

    Args:
        num_classes: Number of output logits of the new ``fc`` head.
        weights: Weights spec forwarded to ``models.resnet50``. ``None`` gives
            random initialization, e.g. for offline smoke tests.

    Returns:
        The model, with ``classifier_params`` (``fc``) and ``backbone_params``
        (``layer4``) attributes stored as lists of parameters for
        ``train_pretrained_model``. Lists can be iterated repeatedly; a
        ``parameters()`` generator cannot.
    """
    model = models.resnet50(weights=weights)

    # Grayscale stem: reuse the pretrained RGB filters, geometry copied from the original.
    rgb_conv = model.conv1
    model.conv1 = nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        dilation=rgb_conv.dilation,
        bias=rgb_conv.bias is not None,
    )
    with torch.no_grad():
        model.conv1.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            model.conv1.bias.copy_(rgb_conv.bias)

    # Freeze everything, then re-enable only the last residual stage.
    model.requires_grad_(False)
    model.layer4.requires_grad_(True)

    # The new head is created after freezing, so it is trainable.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    model.classifier_params = list(model.fc.parameters())
    model.backbone_params = list(model.layer4.parameters())

    return model
resnet_model = build_resnet50(num_classes)
print("✔ ResNet50 ready (last block unfrozen).")


def build_densenet121(num_classes, weights="IMAGENET1K_V1"):
    """Build a torchvision DenseNet121 adapted to 1-channel input for partial fine-tuning.

    - ``features.conv0`` is replaced by a 1-input-channel conv with the same
      geometry. Its filters are the pretrained RGB filters summed over the
      channel axis, so a 1-channel tensor produces exactly the first-layer output
      the original conv gives for that tensor repeated across three channels.
    - ``features`` is frozen except ``denseblock4``. Any layers of ``features``
      after that block (in torchvision, the final ``norm5`` BatchNorm) stay
      frozen and belong to neither optimizer group.
    - ``classifier`` is replaced by a new trainable ``nn.Linear``.

    Args:
        num_classes: Number of output logits of the new classifier.
        weights: Weights spec forwarded to ``models.densenet121``. ``None``
            gives random initialization.

    Returns:
        The model, with ``classifier_params`` (``classifier``) and
        ``backbone_params`` (``features.denseblock4``) attributes stored as lists
        of parameters, so they can be iterated more than once.
    """
    model = models.densenet121(weights=weights)

    # Grayscale stem: reuse the pretrained RGB filters, geometry copied from the original.
    rgb_conv = model.features.conv0
    model.features.conv0 = nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        dilation=rgb_conv.dilation,
        bias=rgb_conv.bias is not None,
    )
    with torch.no_grad():
        model.features.conv0.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            model.features.conv0.bias.copy_(rgb_conv.bias)

    # Freeze the feature extractor, then re-enable only its last dense block.
    model.features.requires_grad_(False)
    model.features.denseblock4.requires_grad_(True)

    # Replace classifier (new layer is trainable by default).
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    model.classifier_params = list(model.classifier.parameters())
    model.backbone_params = list(model.features.denseblock4.parameters())

    return model
densenet_model = build_densenet121(num_classes)
print("✔ DenseNet121 ready (last dense block unfrozen).")


def build_vgg16(num_classes, weights="IMAGENET1K_V1"):
    """Build a torchvision VGG16 adapted to 1-channel input for partial fine-tuning.

    - ``features[0]`` is replaced by a 1-input-channel conv with the same
      geometry. Its filters are the pretrained RGB filters summed over the
      channel axis, and the pretrained bias is copied. Previously the bias was
      left randomly initialized and then frozen.
    - ``features`` is frozen except convolutional block 5 (``features[24:]``
      in torchvision's VGG16 layout).
    - The last ``classifier`` layer is replaced by a new ``nn.Linear``. The
      whole ``classifier`` MLP remains trainable.

    Args:
        num_classes: Number of output logits of the new last classifier layer.
        weights: Weights spec forwarded to ``models.vgg16``. ``None`` gives
            random initialization.

    Returns:
        The model, with ``classifier_params`` (all of ``classifier``) and
        ``backbone_params`` (block 5) attributes stored as lists of parameters,
        so they can be iterated more than once.
    """
    model = models.vgg16(weights=weights)

    # Grayscale stem: reuse the pretrained RGB filters AND bias.
    rgb_conv = model.features[0]
    gray_conv = nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        dilation=rgb_conv.dilation,
        bias=rgb_conv.bias is not None,
    )
    with torch.no_grad():
        gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            gray_conv.bias.copy_(rgb_conv.bias)
    model.features[0] = gray_conv

    # Index where conv block 5 begins in torchvision's VGG16 ``features``.
    block5_start = 24

    # Freeze all conv layers, then re-enable block 5.
    model.features.requires_grad_(False)
    model.features[block5_start:].requires_grad_(True)

    # Replace only the final classifier layer.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)

    model.classifier_params = list(model.classifier.parameters())
    model.backbone_params = list(model.features[block5_start:].parameters())

    return model
vgg_model = build_vgg16(num_classes)
print("✔ VGG16 ready (Block 5 unfrozen).")

✔ VGG16 ready (Block 5 unfrozen).
✔ DenseNet121 ready (last dense block unfrozen).
✔ ResNet50 ready (last block unfrozen).


In [None]:
# Training the pretrained models with partial fine-tuning.
# Shared hyperparameters are defined once so the three runs cannot drift apart.
NUM_EPOCHS = 15
LR_CLASSIFIER = 0.001
LR_BACKBONE = 0.0001

# (display name, model instance, checkpoint path), trained in this order.
training_runs = [
    ("ResNet50", resnet_model, "models/resnet50_model.pth"),
    ("DenseNet121", densenet_model, "models/densenet121_model.pth"),
    ("VGG16", vgg_model, "models/vgg16_model.pth"),
]

trained_models = {}
for run_name, run_model, run_path in training_runs:
    trained_models[run_name] = train_pretrained_model(
        model=run_model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=NUM_EPOCHS,
        lr_classifier=LR_CLASSIFIER,
        lr_backbone=LR_BACKBONE,
        model_name=run_name,
        save_path=run_path,
    )

# Keep the per-model names that later cells may refer to.
trained_resnet50 = trained_models["ResNet50"]
trained_densenet121 = trained_models["DenseNet121"]
trained_vgg16 = trained_models["VGG16"]


===== Training ResNet50 (Partial Fine-Tuning) =====
Epoch 1/15 | Train Loss: 0.3777 | Val Acc: 0.9172
✔ Best model saved → models/resnet50_model.pth
Epoch 2/15 | Train Loss: 0.2393 | Val Acc: 0.9291
✔ Best model saved → models/resnet50_model.pth
Epoch 3/15 | Train Loss: 0.2118 | Val Acc: 0.9169
Epoch 4/15 | Train Loss: 0.1835 | Val Acc: 0.9317
✔ Best model saved → models/resnet50_model.pth
Epoch 5/15 | Train Loss: 0.1665 | Val Acc: 0.9326
✔ Best model saved → models/resnet50_model.pth
Epoch 6/15 | Train Loss: 0.1565 | Val Acc: 0.9348
✔ Best model saved → models/resnet50_model.pth
Epoch 7/15 | Train Loss: 0.1434 | Val Acc: 0.9235
Epoch 8/15 | Train Loss: 0.1382 | Val Acc: 0.9272
Epoch 9/15 | Train Loss: 0.1333 | Val Acc: 0.9364
✔ Best model saved → models/resnet50_model.pth
Epoch 10/15 | Train Loss: 0.1144 | Val Acc: 0.9357
Epoch 11/15 | Train Loss: 0.1096 | Val Acc: 0.9257
Epoch 12/15 | Train Loss: 0.1098 | Val Acc: 0.9380
✔ Best model saved → models/resnet50_model.pth
Epoch 13/15 | T