## Convolutional Neural Network Models

Convolutional neural networks (CNNs) have long been the foundation of computer vision. Although transformer-based models have surpassed them in large-scale, data-rich settings, CNNs remain competitive in efficiency-sensitive and resource-constrained applications. We implemented several CNN models (AlexNet, VGG-16, ResNet-18, and ResNet-34) and compared the performance of the different architectures.

### Model Implementation

#### 1. AlexNet

AlexNet consists of five convolutional layers followed by three fully connected layers, using large receptive fields, ReLU activations, overlapping max-pooling, and dropout to improve training efficiency and reduce overfitting.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlexNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(AlexNet, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc_layers = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, num_classes)
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return x
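
The input size of the first fully connected layer, `256 * 6 * 6`, follows from the layer hyperparameters above when the input is 224×224. The sketch below traces the spatial size with the standard output-size formula `floor((n + 2p - k) / s) + 1`; the helper names `conv_out` and `alexnet_feature_size` are illustrative and not part of the model code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (no dilation, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def alexnet_feature_size(size=224):
    """Trace the spatial side length through the conv and pooling layers above."""
    size = conv_out(size, kernel=11, stride=4, padding=2)  # conv1
    size = conv_out(size, kernel=3, stride=2)              # pool1
    size = conv_out(size, kernel=5, padding=2)             # conv2
    size = conv_out(size, kernel=3, stride=2)              # pool2
    for _ in range(3):                                     # conv3..conv5 keep the size
        size = conv_out(size, kernel=3, padding=1)
    size = conv_out(size, kernel=3, stride=2)              # pool3
    return size


side = alexnet_feature_size(224)
print(side, 256 * side * side)  # 6 9216
```

Because this AlexNet has no adaptive pooling layer, other input resolutions would change the flattened size and require a different first `nn.Linear` layer.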

#### 2. VGG

VGG adopts a simple and uniform architecture: stacks of small 3×3 convolutional layers with ReLU activations, interleaved with max-pooling layers and followed by three fully connected layers for classification. The design emphasizes depth and simplicity over large convolutional kernels. VGG comes in several variants, such as VGG11, VGG13, VGG16, and VGG19, which differ in the number of convolutional layers stacked within each block. This offers a trade-off between model depth and computational cost while maintaining a uniform overall architecture. We leveraged this characteristic to implement the different VGG variants efficiently with one unified class.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum


class VGG(nn.Module):

    class Variant(Enum):
        VGG11 = '11'
        VGG13 = '13'
        VGG16 = '16'
        VGG19 = '19'

    _architectures = {
        Variant.VGG11: [
            64, 'M',
            128, 'M',
            256, 256, 'M',
            512, 512, 'M',
            512, 512, 'M'
        ],
        Variant.VGG13: [
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 'M',
            512, 512, 'M',
            512, 512, 'M'
        ],
        Variant.VGG16: [
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M'
        ],
        Variant.VGG19: [
            64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 256, 'M',
            512, 512, 512, 512, 'M',
            512, 512, 512, 512, 'M'
        ]
    }

    def __init__(self, in_channels, num_classes, variant='16'):
        super(VGG, self).__init__()
        if not isinstance(variant, self.Variant):
            variant = self.Variant(variant)
        self.conv_layers = self._make_conv_layers(
            in_channels,
            self._architectures[variant]
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )
        self._init_weights()

    def _make_conv_layers(self, in_channels, arch):
        layers = []
        channels = in_channels
        for x in arch:
            if isinstance(x, int):
                layers.append(nn.Conv2d(channels, x, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                channels = x
            elif x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                raise ValueError(f"VGG unknown layer: {x}")
        layers.append(nn.AdaptiveAvgPool2d((7, 7)))
        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return x
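
As a quick sanity check on the configuration lists, the sketch below counts weight layers per variant and the spatial size after the pooling stages. It copies the lists from `_architectures` into a plain dictionary; `CONFIGS`, `weight_layers`, and `final_side` are illustrative names, not part of the class.

```python
# Copies of the per-variant configurations from the VGG class above:
# integers are conv output channels, 'M' is a 2x2 max-pooling layer.
CONFIGS = {
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
NUM_FC_LAYERS = 3  # the classifier head always has three linear layers


def weight_layers(arch):
    """Number of conv layers plus fully connected layers."""
    return sum(isinstance(x, int) for x in arch) + NUM_FC_LAYERS


def final_side(arch, size=224):
    """Side length after all max-pooling stages (each one halves the size)."""
    return size // 2 ** arch.count('M')


for name, arch in CONFIGS.items():
    print(name, weight_layers(arch), final_side(arch))
```

Each variant's name matches its weight-layer count, and a 224×224 input reaches 7×7 after five pooling stages, which is why the adaptive average pooling to `(7, 7)` leaves 224×224 inputs unchanged.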

#### 3. ResNet

ResNet introduces residual connections to mitigate the vanishing gradient problem, which makes it possible to train very deep networks effectively. Similar to VGG, it is built from stacks of residual blocks, with popular variants such as ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152 differing in both depth and block structure. We implemented ResNet-18 and ResNet-34 as examples.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from enum import Enum


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    class Variant(Enum):
        RESNET18 = '18'
        RESNET34 = '34'

    _architectures = {
        Variant.RESNET18: (BasicBlock, [2, 2, 2, 2]),
        Variant.RESNET34: (BasicBlock, [3, 4, 6, 3])
    }

    def __init__(self, in_channels, num_classes, variant='18'):
        super(ResNet, self).__init__()
        if not isinstance(variant, self.Variant):
            variant = self.Variant(variant)
        block, layers = self._architectures[variant]
        self.in_channels = 64
        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._init_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
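
The variant names can be related to the block counts in `_architectures` with a small calculation. The sketch below uses the conventional depth count (stem convolution, two convolutions per basic block, final linear layer), which omits the 1×1 projection convolutions on the shortcut path, and also traces the feature-map size per stage for a 224×224 input. All helper names are illustrative.

```python
BLOCKS_PER_STAGE = {'18': [2, 2, 2, 2], '34': [3, 4, 6, 3]}
CONVS_PER_BASIC_BLOCK = 2


def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (no dilation, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_depth(blocks_per_stage):
    """Stem conv + convs inside the residual blocks + final linear layer."""
    return 1 + CONVS_PER_BASIC_BLOCK * sum(blocks_per_stage) + 1


def stage_shapes(size=224):
    """(channels, side length) at the output of layer1..layer4."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max pool
    shapes = []
    for channels, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        # Only the first block of a stage changes the resolution.
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        shapes.append((channels, size))
    return shapes


print({name: resnet_depth(b) for name, b in BLOCKS_PER_STAGE.items()})
# {'18': 18, '34': 34}
print(stage_shapes(224))
# [(64, 56), (128, 28), (256, 14), (512, 7)]
```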

### Experimental Evaluation

We used the following scripts to train and test the models.

In [None]:
import os
import os.path as osp
import argparse
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from time import time
import warnings

from dataset.dataloader import get_multilabel_dataloader

warnings.filterwarnings("ignore", category=UserWarning)


def train_one_epoch(model, loader, criterion, optimizer, max_grad_norm, device, epoch):
    model.train()
    running_loss = 0.0
    start = time()
    for images, targets in loader:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        running_loss += loss.item()
    duration = time() - start
    avg_loss = running_loss / len(loader)
    logging.info(f"epoch {epoch} - loss {avg_loss:.4f} - time {duration:.2f}s")
    return avg_loss


@torch.no_grad()
def validate(model, loader, criterion, device, split):
    model.eval()
    running_loss = 0.0
    for images, targets in loader:
        images = images.to(device)
        targets = targets.to(device)
        outputs = model(images)
        loss = criterion(outputs, targets)
        running_loss += loss.item()
    avg_loss = running_loss / len(loader)
    logging.info(f"{split} loss {avg_loss:.4f}")
    return avg_loss


@torch.no_grad()
def test(model, loader, device, label_dict_path):
    """
    Calculate overall accuracy and recall for the disease class,
    using the 'No Finding' label bit to distinguish normal vs. abnormal.
    """
    model.eval()
    # Load and parse label dictionary into a name→index map
    with open(label_dict_path, "r") as f:
        lines = f.read().splitlines()
    label_to_idx = {}
    for line in lines:
        name, idx = line.split(":")
        label_to_idx[name.strip()] = int(idx.strip())
    # Get index of the 'No Finding' (normal) class
    no_find_idx = label_to_idx["No Finding"]
    # Prepare counters
    total_samples = 0
    correct_preds = 0
    true_positive = 0    # correctly predicted diseased
    false_negative = 0   # diseased samples predicted as normal
    sigmoid = nn.Sigmoid()
    for images, targets in loader:
        images = images.to(device)
        targets = targets.to(device)
        # Model outputs logits for each class
        outputs = model(images)
        probs = sigmoid(outputs)
        # Predicted normal if P(No Finding) >= 0.5
        pred_normal = probs[:, no_find_idx] >= 0.5
        # Ground-truth normal where target bit is 1
        actual_normal = targets[:, no_find_idx] == 1
        # Update accuracy count
        correct_preds += (pred_normal == actual_normal).sum().item()
        total_samples += targets.size(0)
        # For disease (positive) samples (actual_normal == False)
        disease_mask = ~actual_normal
        # Predicted disease if not predicted normal
        pred_disease = ~pred_normal
        true_positive += (disease_mask & pred_disease).sum().item()
        false_negative += (disease_mask & pred_normal).sum().item()
    # Compute metrics
    accuracy = correct_preds / total_samples if total_samples > 0 else 0.0
    recall = (true_positive / (true_positive + false_negative)
              if (true_positive + false_negative) > 0 else 0.0)
    # Log results
    logging.info(f"Test Accuracy: {accuracy * 100:.3f}")
    logging.info(f"Disease Recall: {recall * 100:.3f}")


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Train a CNN or ViT model on multi-label classification")
    parser.add_argument("--img_size", type=int,
                        default=224, help="Input image size")
    parser.add_argument("--ckpt_dir", type=str,
                        default="/home/users/nus/e0945822/scratch/checkpoints", help="Checkpoint directory")
    parser.add_argument("--split_type", type=str, default="balanced", choices=[
                        "balanced", "rare_first", "original", "binary"], help="Split type for dataset")
    parser.add_argument("--save_every", type=int, default=10,
                        help="Save checkpoint every N epochs")
    parser.add_argument("--epochs", type=int, default=50,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int,
                        default=32, help="Batch size")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--index", type=int, default=0, help="GPU device ID")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--test_only", action="store_true",
                        help="Test only without training")
    parser.add_argument("--ckpt_path", type=str, default=None,
                        help="Absolute path to the checkpoint for testing")
    parser.add_argument("--model", choices=["alexnet", "vgg16", "resnet18", "resnet34", "vit_pt"],
                        required=True, help="The model to run")
    args = parser.parse_args()
    # Set logging configuration
    log_dir = osp.join(
        "logs",
        datetime.now().strftime("%Y-%m-%d"),
    )
    os.makedirs(log_dir, exist_ok=True)
    log_file = osp.join(
        log_dir,
        f"{args.model}" + f"{'-test' if args.test_only else ''}" +
        f"-{args.split_type}-{datetime.now().strftime('%H-%M-%S')}.log"
    )
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.FileHandler(log_file, mode='w'),
            logging.StreamHandler()
        ]
    )

    # Get data loaders
    logging.info("Loading data...")
    torch.manual_seed(args.seed)
    if args.img_size == 224:
        data_dir = "/home/users/nus/e0945822/scratch/5242data/data_224/"
        if not os.path.exists(data_dir):
            data_dir = "data_tensor"
        train_loader = get_multilabel_dataloader(
            data_dir,
            split_type=args.split_type,
            split="train",
            batch_size=args.batch_size,
            shuffle=True
        )
        val_loader = get_multilabel_dataloader(
            data_dir,
            split_type=args.split_type,
            split="val",
            batch_size=args.batch_size,
            shuffle=False
        )
        test_loader = get_multilabel_dataloader(
            data_dir,
            split_type=args.split_type,
            split="test",
            batch_size=args.batch_size,
            shuffle=False
        )
    else:
        raise ValueError("Invalid image size. Only supporting 224.")

    # Initialize model, loss function, and optimizer
    logging.info("Initializing model...")
    device = torch.device(args.device)
    if args.device == "cuda":
        device = torch.device(f"cuda:{args.index}")
        torch.cuda.set_device(args.index)

    model_args = {
        "in_channels": 1,
        "num_classes": 15
    }

    if args.model == 'alexnet':
        from models.alexnet import AlexNet
        model = AlexNet(**model_args).to(device)
    elif args.model == 'vgg16':
        from models.vgg import VGG
        model = VGG(**model_args, variant='16').to(device)
    elif args.model in ('resnet18', 'resnet34'):
        from models.resnet import ResNet
        # 'resnet18' -> '18', 'resnet34' -> '34'
        variant = args.model.split('resnet')[1]
        model = ResNet(**model_args, variant=variant).to(device)
    elif args.model == 'vit_pt':
        from models.vit_pt import ViT_PT
        model = ViT_PT(**model_args).to(device)

    if args.ckpt_path:
        model.load_state_dict(torch.load(args.ckpt_path))
        logging.info(f"Loaded model from {args.ckpt_path}")
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)
    max_grad_norm = 1.0
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        verbose=True
    )

    # Main training loop
    if not args.test_only:
        logging.info("Starting training...")
        best_val_loss = float("inf")
        no_improve_epochs = 0
        early_stop_patience = 10
        ckpt_dir = osp.join(args.ckpt_dir, f"{args.model}" +
                            f"-{args.split_type}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")
        os.makedirs(ckpt_dir, exist_ok=True)
        for epoch in range(1, args.epochs + 1):
            train_one_epoch(model, train_loader, criterion,
                            optimizer, max_grad_norm, device, epoch)
            # Gradually relax (raise) the gradient clipping threshold after epoch 10
            if epoch > 10:
                max_grad_norm = min(max_grad_norm + 0.1, 5.0)

            val_loss = validate(model, val_loader, criterion, device, "val")
            scheduler.step(val_loss)
            # Early stopping if no improvement or learning rate is too low
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improve_epochs = 0
                torch.save(
                    model.state_dict(),
                    osp.join(ckpt_dir, "best_model.pt")
                )
                logging.info(
                    f"Best model saved at epoch {epoch} with val loss {val_loss:.4f}")
            else:
                no_improve_epochs += 1
            current_lr = optimizer.param_groups[0]['lr']
            if no_improve_epochs >= early_stop_patience or current_lr < 1e-6:
                logging.info(f"Early stopping triggered at epoch {epoch}.")
                break

            # Periodic checkpoint saving
            if epoch % args.save_every == 0:
                torch.save(
                    model.state_dict(),
                    osp.join(ckpt_dir, f"checkpoint_epoch{epoch}.pt")
                )
                logging.info(f"Checkpoint saved at epoch {epoch}")

        # Save the final model
        torch.save(
            model.state_dict(),
            osp.join(ckpt_dir, "final_model.pt")
        )
        logging.info("Final model saved.")

    # Testing
    logging.info("Testing...")
    # Load the best model for testing
    if not args.test_only:
        ckpt_path = osp.join(ckpt_dir, "best_model.pt")
    else:
        if args.ckpt_path is None:
            raise ValueError("Please provide a checkpoint path for testing.")
        ckpt_path = args.ckpt_path
    model.load_state_dict(torch.load(ckpt_path))
    logging.info(f"Loaded tested model from {ckpt_path}")
    model.to(device)
    test(model, test_loader, device, osp.join(data_dir, "label_dict.txt"))


if __name__ == "__main__":
    main()
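
The two pieces of control logic in the training loop, the relaxing gradient-clipping threshold and patience-based early stopping, can be examined in isolation. The following is a minimal sketch that mirrors the loop in `main()` without any model; the function names are illustrative, the validation losses are made-up numbers, and the learning-rate stopping criterion is left out for brevity.

```python
def clip_threshold_schedule(epochs, start=1.0, step=0.1, cap=5.0, warmup=10):
    """Clipping threshold used in each epoch 1..epochs."""
    thresholds = []
    max_grad_norm = start
    for epoch in range(1, epochs + 1):
        thresholds.append(max_grad_norm)  # value passed to train_one_epoch
        if epoch > warmup:                # raised only after the warm-up epochs
            max_grad_norm = min(max_grad_norm + step, cap)
    return thresholds


def early_stop_epoch(val_losses, patience=10):
    """Return (epoch at which training stops, epoch of the best checkpoint)."""
    best_loss = float("inf")
    best_epoch = None
    no_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, no_improve = loss, epoch, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            return epoch, best_epoch
    return len(val_losses), best_epoch


thresholds = clip_threshold_schedule(50)
print(thresholds[10], round(thresholds[11], 2))  # epochs 11 and 12

# Hypothetical validation losses: improvement stops after epoch 3.
print(early_stop_epoch([0.9, 0.7, 0.6] + [0.65] * 12))  # (13, 3)
```

With the defaults, the threshold first changes in epoch 12, and training stops ten epochs after the last improvement, while `best_model.pt` still holds the weights from the best epoch.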


The ViT model usually reached its best validation accuracy at around 20 epochs, with the binary cross-entropy training and validation losses both around $0.32$. After the first 20 epochs, the model starts to overfit: the training loss continues to decrease while the validation loss starts to increase. The training process is usually stable, and the model converges well.

We compared the performance of the ViT trained from scratch under different sampling strategies; the results are shown in the table below. The evaluation uses the model checkpoint with the best performance on the validation set.

| Sampling Strategy | Accuracy (%) | Recall (%) |
| :---------------: | :----------: | :--------: |
|     Original      |    62.91     |   56.17    |
|    Round-robin    |    92.88     |   100.00   |
|    Rare-first     |    100.00    |   100.00   |

From the table, we can see that the round-robin sampling strategy significantly improved the model's performance, achieving an accuracy of 92.88% and a recall of 100%. The rare-first sampling strategy achieved perfect accuracy and recall.
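
Assuming the accuracy and recall in the table are the two values logged by `test()`, they are derived from a single binary decision per image: the sample is predicted normal when the sigmoid probability of the 'No Finding' label is at least 0.5. The sketch below restates that rule in plain Python on made-up inputs; `binary_metrics` is an illustrative name.

```python
def binary_metrics(p_no_finding, is_normal, threshold=0.5):
    """Accuracy of normal-vs-abnormal and recall on the abnormal (disease) class."""
    correct = true_positive = false_negative = 0
    for prob, normal in zip(p_no_finding, is_normal):
        pred_normal = prob >= threshold
        correct += pred_normal == normal
        if not normal:                # diseased sample
            if pred_normal:
                false_negative += 1   # missed disease
            else:
                true_positive += 1    # detected disease
    total = len(is_normal)
    diseased = true_positive + false_negative
    accuracy = correct / total if total else 0.0
    recall = true_positive / diseased if diseased else 0.0
    return accuracy, recall


# Hypothetical probabilities and ground truth for four images.
print(binary_metrics([0.9, 0.2, 0.7, 0.4], [True, False, False, True]))
# (0.5, 0.5)
```

Under this rule a diseased sample counts as detected whenever P(No Finding) falls below the threshold, regardless of which specific disease labels the model predicts.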