# Model Definition for Chest X-Ray Classification

This notebook defines the model architecture: a ResNet-18 pretrained on ImageNet. Its final fully connected layer is replaced with a dropout layer (p = 0.5) followed by a new linear layer that outputs one score per class: 5 by default, one for each respiratory disease being classified.
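To make the head replacement concrete, the sketch below builds only the new classifier head in isolation and feeds it a batch of random stand-in feature vectors. It assumes the 512-dimensional pooled features that ResNet-18 produces before its `fc` layer; the batch size of 8 is arbitrary. The head emits raw logits with no softmax, which is the input format `nn.CrossEntropyLoss` expects if training uses that loss.

```python
import torch
import torch.nn as nn

# ResNet-18's globally pooled feature vector has 512 entries,
# which is what model.fc.in_features reports for the stock model.
num_ftrs = 512
num_classes = 5

head = nn.Sequential(
    nn.Dropout(0.5),                   # randomly zeroes features, in training mode only
    nn.Linear(num_ftrs, num_classes),  # one raw score (logit) per class
)

# Trainable parameters in the new head: weights (512 * 5) plus biases (5).
n_params = sum(p.numel() for p in head.parameters())
print(n_params)  # 2565

features = torch.randn(8, num_ftrs)  # stand-in batch of 8 pooled feature vectors
head.eval()                          # disable dropout for inference
logits = head(features)
print(logits.shape)  # torch.Size([8, 5])
```

Only these 2,565 parameters are new; everything before `fc` keeps its pretrained ImageNet values until fine-tuning updates them.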

```python
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
```

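```python
def create_model(num_classes=5):
    """
    Create an ImageNet-pretrained ResNet-18 whose final layer is replaced
    with dropout followed by a linear layer sized for num_classes.

    Args:
        num_classes (int): Number of output classes (default: 5).

    Returns:
        nn.Module: The modified ResNet-18, which outputs raw logits of
        shape (batch_size, num_classes).
    """
    # Load ResNet-18 with ImageNet-1k pretrained weights (downloaded on first use).
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    # Size of the pooled feature vector fed to the original classifier (512 for ResNet-18).
    num_ftrs = model.fc.in_features
    # Replace the 1000-class ImageNet head with a dropout-regularized head.
    model.fc = nn.Sequential(
        nn.Dropout(0.5),                   # zero 50% of features during training only
        nn.Linear(num_ftrs, num_classes),  # one logit per class
    )
    return model
```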