# Device Setup

In [1]:
import torch

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Running on: {device}")

# Downloading Data

In [3]:
# downloading train_dataset from dataverse.harvard.edu named Chest X-Ray Dataset for Respiratory Disease Classification
!wget https://dataverse.harvard.edu/api/access/datafile/5194114

In [4]:
# The download is saved under its numeric datafile id with no extension;
# give it the .npz suffix that the NPZ_PATH constant used later expects.
!mv /content/5194114 /content/5194114.npz #renaming the file to .npz

# Imports

In [5]:
import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms
from tqdm import tqdm


In [6]:
# Human-readable class names, indexed by the integer label stored in the dataset.
CLASS_NAMES  = ["covid", "lung_opacity", "normal", "viral_pneumonia", "tuberculosis"]
# Derived from CLASS_NAMES so the two can never drift apart.
NUM_CLASSES  = len(CLASS_NAMES)

# Dataset

In [7]:
class ChestXrayNPZDataset(Dataset):
    """
    Map-style dataset over a chest X-ray ``.npz`` archive.

    The archive is opened once and its ``'image'`` and ``'image_label'`` arrays
    are read into memory once, in ``__init__``. NumPy does not memory-map the
    members of an ``.npz`` (zip) archive, and every ``NpzFile[key]`` lookup
    re-reads the whole member, so the arrays are cached as attributes instead
    of being looked up on every ``__getitem__`` call.

    Args:
        npz_path:  path to the ``.npz`` archive.
        transform: optional callable applied to each ``(1, H, W)`` float tensor.

    Raises:
        ValueError: if the image and label arrays have different lengths.
    """
    def __init__(self, npz_path, transform=None):
        self.transform = transform
        self.data      = np.load(npz_path)           # NpzFile handle, closed in __del__
        self.images    = self.data['image']          # full image array, loaded once
        self.labels    = self.data['image_label']    # full label array, loaded once
        self.length    = self.images.shape[0]

        if self.labels.shape[0] != self.length:
            raise ValueError(
                f"'image' has {self.length} entries but 'image_label' has "
                f"{self.labels.shape[0]}"
            )

        # Integer-typed pixels are always on a 0-255 scale; remembering this
        # avoids misclassifying a nearly black image as "already in [0, 1]".
        self._integer_pixels = bool(np.issubdtype(self.images.dtype, np.integer))
        print(f"Dataset ready: {self.length} images")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, label)``: a ``(1, H, W)`` float32 tensor and a long scalar."""
        image = torch.tensor(np.array(self.images[idx]), dtype=torch.float32)
        label = torch.tensor(self.labels[idx].item(), dtype=torch.long)

        if image.ndim == 2:
            # Already single-channel: (H, W) -> (1, H, W)
            image = image.unsqueeze(0)
        else:
            # (H, W, C) -> (C, H, W), then average channels -> (1, H, W)
            image = image.permute(2, 0, 1).mean(dim=0, keepdim=True)

        # Normalize to [0, 1]. Float arrays fall back to the value-range check.
        if getattr(self, "_integer_pixels", False) or image.max() > 1.0:
            image = image / 255.0

        if self.transform:
            image = self.transform(image)

        return image, label

    def __del__(self):
        # __init__ may have failed before self.data was assigned; never raise here.
        data = getattr(self, "data", None)
        if data is not None:
            data.close()

# Transforms

In [8]:

# Training-time pipeline for (1, H, W) tensors: resize, light geometric
# augmentation, then single-channel normalization.
_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
]

transform = transforms.Compose(
    [transforms.Resize((224, 224))]
    + _augmentations
    + [transforms.Normalize(mean=[0.485], std=[0.229])]
)



# Load Data & Split into Train / Val / Test

In [9]:
NPZ_PATH   = "/content/5194114.npz"
BATCH_SIZE = 16
VAL_SPLIT  = 0.15
TEST_SPLIT = 0.10

# Deterministic preprocessing for validation/test: same resize and
# normalization as training, but without any random augmentation, so that
# evaluation metrics are stable and comparable between epochs and runs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])


class TransformedSubset(Dataset):
    """Wraps a subset of an untransformed dataset and applies a split-specific transform."""

    def __init__(self, subset, transform=None):
        self.subset    = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# The base dataset applies no transform; each split adds its own below.
# This keeps a single copy of the arrays in memory while letting only the
# training split be augmented.
full_dataset = ChestXrayNPZDataset(NPZ_PATH, transform=None)

n       = len(full_dataset)
n_val   = int(n * VAL_SPLIT)
n_test  = int(n * TEST_SPLIT)
n_train = n - n_val - n_test

train_split, val_split, test_split = random_split(
    full_dataset,
    [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42)   # reproducible split every run
)

train_ds = TransformedSubset(train_split, transform)        # augmented
val_ds   = TransformedSubset(val_split,   eval_transform)   # deterministic
test_ds  = TransformedSubset(test_split,  eval_transform)   # deterministic

# Pinned host memory only helps (and only avoids a warning) when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)

print(f"Split -> Train: {n_train} | Val: {n_val} | Test: {n_test}")




In [10]:
# Sanity check: show the first stored image together with its class name.
raw_label = int(full_dataset.labels[0].item())
raw_image = np.array(full_dataset.images[0], dtype=np.float32)

# Bring 0-255 pixel values down to [0, 1]; leave already-scaled data alone.
if raw_image.max() > 1.0:
    raw_image = raw_image / 255.0

preview_fig, preview_ax = plt.subplots(figsize=(4, 4))
preview_ax.imshow(raw_image.mean(axis=-1), cmap='gray')   # average channels for display
preview_ax.set_title(f"Label: {CLASS_NAMES[raw_label]}")
preview_ax.axis('off')
plt.show()

# Model (ResNet50 adapted for grayscale)

In [11]:
class ChestXrayResNet(nn.Module):
    """ResNet50 with ImageNet weights, a 1-channel stem and a custom classifier head.

    Args:
        num_classes: size of the final logits vector.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        # Weights enum API (the older pretrained=True flag is deprecated).
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Grayscale stem: same geometry as the original first conv, but with a
        # single input channel. Its kernels are the channel-wise mean of the
        # pretrained RGB kernels, so the learned filters are kept rather than
        # re-initialised.
        rgb_kernels = backbone.conv1.weight
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_stem.weight.copy_(rgb_kernels.mean(dim=1, keepdim=True))
        backbone.conv1 = gray_stem

        # Classifier head: dropout -> hidden layer -> dropout -> logits.
        feature_dim = backbone.fc.in_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)

        self.resnet = backbone

    def forward(self, x):
        """Return class logits for a batch of single-channel images."""
        logits = self.resnet(x)
        return logits


# Build the network and report how many parameters will receive gradients.
model = ChestXrayResNet(num_classes=NUM_CLASSES)

trainable_counts = [weights.numel() for weights in model.parameters() if weights.requires_grad]
trainable = sum(trainable_counts)
print(f"Model ready. Trainable parameters: {trainable:,}")

# Training Loop

In [12]:
NUM_EPOCHS = 50
SAVE_PATH  = "best_model.pth"


def _run_epoch(model, loader, criterion, optimizer=None, desc=""):
    """Run one full pass over `loader` and return (mean loss, accuracy).

    When `optimizer` is given the model is put in train mode and updated after
    every batch; otherwise the model is put in eval mode and gradients are
    disabled.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.set_grad_enabled(is_training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss    = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch average is per-sample even when the last batch is smaller.
            loss_sum  += loss.item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen    += labels.size(0)

    if n_seen == 0:
        raise ValueError(f"DataLoader for '{desc.strip()}' yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS,
                lr=1e-4, weight_decay=1e-4, save_path=SAVE_PATH):
    """Train `model`, checkpointing the weights with the lowest validation loss.

    Args:
        model:        network to train; moved to the global `device`.
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader:   DataLoader yielding (images, labels) validation batches.
        num_epochs:   number of passes over the training data.
        lr:           Adam learning rate.
        weight_decay: Adam weight decay.
        save_path:    file the best state_dict is written to.

    Returns:
        dict with per-epoch lists under "train_loss", "val_loss",
        "train_acc" and "val_acc".

    Raises:
        ValueError: if either loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the learning rate after 3 epochs without validation-loss improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    model.to(device)
    best_val_loss = float('inf')
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, optimizer,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
        )
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]  ",
        )
        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"\nEpoch {epoch+1:02d}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f}  Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}  Acc: {val_acc:.4f}")

        # Save only when validation loss improves (best model, not last epoch)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"  -> Best model saved (val_loss={val_loss:.4f})")

    return history


# Kick off training; the returned per-epoch metrics feed the plots below.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
)

# Plot Training Curves

In [18]:
# One x-axis tick per completed epoch, starting at 1.
epochs_range = range(1, len(history["train_loss"]) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, history-key suffix, legend word, title / y-label word)
curve_panels = [
    (ax1, "loss", "Loss", "Loss"),
    (ax2, "acc",  "Acc",  "Accuracy"),
]

for panel_ax, key, short_name, long_name in curve_panels:
    panel_ax.plot(epochs_range, history[f"train_{key}"], label=f"Train {short_name}")
    panel_ax.plot(epochs_range, history[f"val_{key}"],   label=f"Val {short_name}")
    panel_ax.set_title(f"{long_name} per Epoch")
    panel_ax.set_xlabel("Epoch")
    panel_ax.set_ylabel(long_name)
    panel_ax.legend()

plt.tight_layout()
plt.show()


# Final Evaluation on Held-Out Test Set

In [19]:
# Always load the BEST checkpoint, not the weights from the last epoch.
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
model.load_state_dict(torch.load(SAVE_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating on test set"):
        outputs = model(images.to(device))
        preds   = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Pin the label set explicitly: without it, a class that happens to be absent
# from the test split makes classification_report reject target_names and
# shrinks the confusion matrix so it no longer lines up with CLASS_NAMES.
class_ids = list(range(len(CLASS_NAMES)))

print("\n===== Classification Report =====")
print(classification_report(all_labels, all_preds, labels=class_ids,
                            target_names=CLASS_NAMES, zero_division=0))

# Confusion matrix heatmap
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, cmap='Blues')
plt.title("Confusion Matrix — Test Set")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.show()


# Image Inference

In [21]:
def predict_batch(model, dataset, indices, class_names, preprocess=None, max_cols=5):
    """
    Run inference on multiple images sequentially and display results.

    Each image is shown next to its true and predicted class (green title when
    they match, red otherwise) and its per-class probabilities are printed.

    Args:
        model:       trained model
        dataset:     dataset object exposing `.images` and `.labels` arrays
        indices:     iterable of dataset indices to predict on
        class_names: list of class name strings
        preprocess:  optional callable applied to the (1, H, W) tensor in [0, 1].
                     Defaults to a deterministic resize + normalize, i.e. the
                     training pipeline without its random augmentations, so the
                     prediction is made on the same image that is displayed.
        max_cols:    maximum number of images per row in the figure
    """
    indices = list(indices)
    if not indices:
        return  # nothing to predict or plot

    if preprocess is None:
        # Keep these values in sync with the training transform.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485], std=[0.229]),
        ])

    model.eval()
    # Send inputs to wherever the model's weights currently live.
    model_device = next(model.parameters()).device

    # Size the grid from the number of images (10 images -> 2 rows x 5 cols).
    n_cols = max(1, min(max_cols, len(indices)))
    n_rows = (len(indices) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for plot_idx, data_idx in enumerate(indices):
        # Load raw image and true label
        raw_image  = np.array(dataset.images[data_idx], dtype=np.float32)
        true_label = int(dataset.labels[data_idx].item())

        # Preprocess: grayscale + [0, 1] scaling, then the deterministic pipeline
        t = torch.from_numpy(raw_image)
        t = t.permute(2, 0, 1)           # (H, W, C) -> (C, H, W)
        t = t.mean(dim=0, keepdim=True)  # (C, H, W) -> (1, H, W)
        if t.max() > 1.0:
            t = t / 255.0
        t = preprocess(t).unsqueeze(0).to(model_device)

        # Inference
        with torch.no_grad():
            probs = torch.softmax(model(t), dim=1)[0].cpu()
            pred  = probs.argmax().item()

        # Plot
        correct = pred == true_label
        color   = "green" if correct else "red"
        axes[plot_idx].imshow(raw_image.mean(axis=-1), cmap='gray')
        axes[plot_idx].set_title(
            f"True:  {class_names[true_label]}\n"
            f"Pred:  {class_names[pred]} ({probs[pred]:.1%})",
            color=color, fontsize=10
        )
        axes[plot_idx].axis('off')

        # Print probabilities for each image
        print(f"\nImage {data_idx}:")
        for i, name in enumerate(class_names):
            bar = "#" * int(probs[i].item() * 30)
            print(f"  {name:<20}: {probs[i]:.4f}  {bar}")

    # Hide any grid cells left over when the count is not a multiple of n_cols.
    for unused_ax in axes[len(indices):]:
        unused_ax.axis('off')

    plt.suptitle("Batch Inference — Green = Correct, Red = Incorrect", fontsize=13)
    plt.tight_layout()
    plt.show()


# Use a dedicated, seeded RNG so the same images are shown on every
# Restart & Run All, and cap the sample size at the dataset size.
# NOTE(review): full_dataset also contains training images, so this gallery is
# a qualitative demo rather than a held-out evaluation.
SAMPLE_SEED = 42
sample_indices = random.Random(SAMPLE_SEED).sample(
    range(len(full_dataset)), min(10, len(full_dataset))
)
predict_batch(model, full_dataset, indices=sample_indices, class_names=CLASS_NAMES)