# Project Baseline Setup:

### Architecture
Run pre-training with contrastive learning (MoCo framework)
* Use backbone - `ViT-S/16`
* **Perform SSL pre-training** of the backbone using contrastive learning (MoCo): augment the medical imaging dataset (chest X-rays) to create positive and negative pairs (pretext task)
  * Produce: `moco_vits16_encoder.pth`
* **Transfer Learning:** Fine-tune the pre-trained ViT-S/16 backbone for pneumonia classification on chest X-rays
  * Produce: `finetuned_vits16_medical.pth`

### Pre-training Dataset: CheXpert
* Subset: pneumonia classification only; a smaller dataset (to accommodate class imbalance)

### Fine-tuning Dataset: NIH Pneumonia Dataset

###

In [1]:
# import libraries
import os
import sys
import argparse
from tqdm import tqdm

import numpy as np
import seaborn as sns
import pandas as pd
import math
from copy import deepcopy

%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import display

import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import random_split, DataLoader

# Run on the GPU when one is attached to the runtime, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Colab Setup

In [3]:
# Mount Google Drive at /content/drive; FP_ROOT below points inside this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Define file paths for the required inputs (scripts, data) and outputs.

In [4]:
# ----------------------------------------------------
# Inputs root
# ----------------------------------------------------
# Every root below ends in "/", so sub-paths are appended WITHOUT a leading
# "/" (the f"{SRC_ROOT}/..." form produced paths like "src//moco").
FP_ROOT = "/content/drive/MyDrive/ViT_MoCo_Project/"

# ----------------------------------------------------
# Dataset info (CheXpert)
# ----------------------------------------------------
# Zipped image archive; extracted to /tmp in the "Unzip data" section
DATASET_PATH_SPLIT = FP_ROOT + "Data/CheXpert_reduced_dataset_split_v3.zip"

# Label CSVs for MoCo pre-training
TRAIN_LABELS_CSV = FP_ROOT + "Data/1_final_project_updated_names_train_moco.csv"
TEST_LABELS_CSV = FP_ROOT + "Data/1_final_project_updated_names_test_moco.csv"

# Label CSVs for the linear-classifier evaluation of the backbone
LINEAR_TRAIN_LABELS_CSV = FP_ROOT + "Data/2_final_project_updated_names_train_linear.csv"
LINEAR_TEST_LABELS_CSV = FP_ROOT + "Data/2_final_project_updated_names_test_linear.csv"

# ----------------------------------------------------
# Model SRC
# ----------------------------------------------------
SRC_ROOT = FP_ROOT + "src/"
MOCO_FOLDER = SRC_ROOT + "moco"
# The training script lives inside src/moco/ (it is launched as
# `moco/train_moco_unified.py` from SRC_ROOT), not directly under src/.
TRAIN_SCRIPT = MOCO_FOLDER + "/train_moco_unified.py"

# ----------------------------------------------------
# Outputs
# ----------------------------------------------------
ROOT_ARTIFACT_SAVE = FP_ROOT + "artifacts/vit_baseline/"

In [5]:
# Add the project /src to the system path
# Make modules under src/ importable from this kernel; the membership test
# keeps re-runs of this cell from appending duplicate entries.
_src_missing = SRC_ROOT not in sys.path
if _src_missing:
    sys.path.append(SRC_ROOT)
    print(f"Added {SRC_ROOT} to sys.path")

Added /content/drive/MyDrive/ViT_MoCo_Project/src/ to sys.path


In [6]:
# Sanity check: SRC_ROOT should now be listed in the module search path.
print(sys.path)

['/content', '/env/python', '/usr/lib/python312.zip', '/usr/lib/python3.12', '/usr/lib/python3.12/lib-dynload', '', '/usr/local/lib/python3.12/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.12/dist-packages/IPython/extensions', '/root/.ipython', '/tmp/tmph85ibkmc', '/content/drive/MyDrive/ViT_MoCo_Project/src/']


## Unzip data

CheXpert

In [7]:
# Unzip the dataset (image) files to /tmp
DATA_DEST_UNZIPPED = "/tmp/CheXpert_dataset/"
os.makedirs(DATA_DEST_UNZIPPED, exist_ok=True)
print("Extracting dataset...")
!unzip -q "{DATASET_PATH_SPLIT}" -d {DATA_DEST_UNZIPPED}
print("Dataset extracted to:", DATA_DEST_UNZIPPED)

Extracting dataset...
Dataset extracted to: /tmp/CheXpert_dataset/


In [8]:
# Update for the unzipped sub-name
# The archive unpacks into its own top-level folder; point at that folder so the
# `--root_dir` arguments passed to the scripts below resolve to the split dirs.
_UNZIP_BASE = "/tmp/CheXpert_dataset/"
DATA_DEST_UNZIPPED = _UNZIP_BASE + "CheXpert_reduced_dataset_split_v3/"

## 1) Run Pre-training - Contrastive Learning

In [9]:
# Run from the src directory
%cd "/content/drive/MyDrive/ViT_MoCo_Project/src"

/content/drive/MyDrive/ViT_MoCo_Project/src


In [10]:
import os

# Expose src/ to subprocesses launched with `! python ...` (presumably so the
# scripts can import the project's `moco` package). Prepend rather than
# overwrite, so existing PYTHONPATH entries (e.g. '/env/python' in this runtime's
# sys.path) survive, and skip the insert if it is already present (idempotent).
_src_dir = SRC_ROOT.rstrip("/")
_pythonpath_entries = [p for p in os.environ.get("PYTHONPATH", "").split(os.pathsep) if p]
if _src_dir not in _pythonpath_entries:
    _pythonpath_entries.insert(0, _src_dir)
os.environ["PYTHONPATH"] = os.pathsep.join(_pythonpath_entries)

In [11]:
# Launch MoCo pre-training as a subprocess. IPython substitutes the `$NAME`
# references with the notebook's Python variables before the shell runs.
# The script path is relative, so the `%cd` into src/ must have run first.
# NOTE(review): `--num_workers 12` is fixed; confirm it does not exceed the
# runtime's CPU count (os.cpu_count()).
# NOTE(review): `--out_model_name` differs from the `moco_vits16_encoder.pth`
# artifact named in the project overview; align one with the other.
! python moco/train_moco_unified.py \
    --precision fp16 \
    --train_csv_path "$TRAIN_LABELS_CSV" \
    --root_dir "$DATA_DEST_UNZIPPED" \
    --artifact_root "$ROOT_ARTIFACT_SAVE" \
    --model_type "VIT_S16" \
    --batch_size 64 \
    --n_epochs 100 \
    --num_workers 12 \
    --out_model_name "vit_s16_moco_encoder_v1.pth"

Created log file:  /content/drive/MyDrive/ViT_MoCo_Project/artifacts/vit_baseline/moco_training_log_20251209_123855.txt
[2025-12-09 12:38:55] Creating MoCo medical image Train DataLoader... : from /content/drive/MyDrive/ViT_MoCo_Project/Data/1_final_project_updated_names_train_moco.csv
Loading dataset from: /content/drive/MyDrive/ViT_MoCo_Project/Data/1_final_project_updated_names_train_moco.csv
Training Dataset size: 49500 images
[2025-12-09 12:38:57] Using backbone: VIT_S16
Using timm ViT backbone: vit_small_patch16_224
[2025-12-09 12:38:58] Starting Unified MoCo training for 100 epochs...
[2025-12-09 12:38:58] Precision mode = fp16
  with autocast(dtype=dtype_autocast):
Epoch 0/99: 100% 773/773 [03:37<00:00,  3.55it/s, loss=6.2, pos_cos=0.998]
[2025-12-09 12:42:35] Epoch complete 0, Avg. Running Loss: 6.435738
[2025-12-09 12:42:35] POS_COS=0.9503
[2025-12-09 12:42:35] Saving initial checkpoint at epoch 1; Avg. Running Loss: 6.435738
[2025-12-09 12:42:50] Saving checkpoint at epoch 1

In [None]:
class PreNormEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        embed_dim: Token embedding width (nn.MultiheadAttention requires it to
            be divisible by ``heads``).
        heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used on the attention weights, after the
            GELU, and after the MLP output projection.

    Inputs are batch-first, i.e. (batch, seq, embed_dim); the output has the
    same shape.
    """

    def __init__(self, embed_dim, heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Normalize once and share the result as query, key and value.
        h = self.norm1(x)
        # The attention map is discarded, so don't ask MultiheadAttention to
        # compute/average it (need_weights=False also allows the fused kernel).
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class SimpleTransformer(nn.Module):
    """Stack of ``depth`` PreNormEncoderLayer blocks applied one after another.

    Every block is residual, so the output has the same shape as the input.
    NOTE(review): no final LayerNorm follows the last pre-norm block.
    """

    def __init__(self, embed_dim=512, depth=12, heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        blocks = [
            PreNormEncoderLayer(embed_dim, heads, mlp_ratio, dropout)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

## Test Backbone with Linear Classifier

In [None]:
# Run from the src directory
%cd "/content/drive/MyDrive/ViT_MoCo_Project/src"

# Hyperparams
batch_size = 128

In [None]:
# Evaluate the pre-trained backbone with a linear classifier (2 classes, 10 epochs)
# on the linear-train/test CSVs; `$NAME` references are expanded by IPython.
# NOTE(review): `--model_checkpoint` names a `vit_hybrid_...epoch_37` checkpoint,
# not the `vit_s16_moco_encoder_v1.pth` produced by the pre-training cell above;
# confirm which backbone this is meant to test.
# NOTE(review): the recorded output appears stale: it is dated before the
# pre-training run and loads from `artifacts/` instead of ROOT_ARTIFACT_SAVE.
! python moco/test_moco_vit_hybrid.py \
    --root_dir "$DATA_DEST_UNZIPPED" \
    --artifact_root "$ROOT_ARTIFACT_SAVE" \
    --test_num_classes 2 \
    --linear_n_epochs 10 \
    --linear_train_csv_path "$LINEAR_TRAIN_LABELS_CSV" \
    --linear_test_csv_path "$LINEAR_TEST_LABELS_CSV" \
    --model_checkpoint "vit_hybrid_moco_checkpoint_epoch_37.pth" \
    --batch_size "$batch_size"

Created log file:  /content/drive/MyDrive/ViT_MoCo_Project/artifacts/moco_backbone_testing_log_20251207_114444.txt
[2025-12-07 11:44:44] Starting MoCo backbone testing...
[2025-12-07 11:44:44] Loading pretrained MoCo model from /content/drive/MyDrive/ViT_MoCo_Project/artifacts/vit_hybrid_moco_checkpoint_epoch_37.pth...
[2025-12-07 11:44:49] Detected 'model_state' in checkpoint.
[2025-12-07 11:44:49] Loaded model. Missing keys: [], unexpected keys: []
[2025-12-07 11:44:49] Loaded pretrained encoder. missing keys: [], unexpected: []
Loading train dataset...
CSV: /content/drive/MyDrive/ViT_MoCo_Project/Data/2_final_project_updated_names_train_linear.csv
 * Images - Train Root Directory: /tmp/CheXpert_dataset/CheXpert_reduced_dataset_split_v3/train
Unique labels in column 'Pneumonia': [1 0]
Loading test dataset...
CSV: /content/drive/MyDrive/ViT_MoCo_Project/Data/2_final_project_updated_names_test_linear.csv
 * Images - Test Root Directory: /tmp/CheXpert_dataset/CheXpert_reduced_dataset_sp

In [1]:
def visualize_patch_importance(model, dataloader, device, k=10, show=True):
    """Outline the ``k`` highest-scoring ViT patches on the first image of a batch.

    Args:
        model: Module called as ``model(batch)`` that returns
            ``(logits, patch_scores)``; ``patch_scores[0]`` must hold one score
            per patch of a square patch grid (e.g. 196 = 14 x 14). The model is
            switched to eval mode as a side effect.
        dataloader: Iterable yielding ``(images, labels)``; only the first image
            of the first batch is used.
        device: Device the image is moved to before the forward pass.
        k: Number of top-scoring patches to outline; ``k <= 0`` outlines none.
        show: If True, display the annotated image with matplotlib.

    Returns:
        ``(img_draw, scores)``: the annotated uint8 RGB image of shape
        (H, W, 3) and the numpy array of patch scores for that image.

    Raises:
        ValueError: if the number of patch scores is not a perfect square.
    """
    import math
    import numpy as np
    import torch

    # 1) Get one sample from the loader
    model.eval()
    images, _labels = next(iter(dataloader))
    img = images[0].to(device)              # (C, H, W)

    with torch.no_grad():
        _logits, patch_scores = model(img.unsqueeze(0))

    # 2) Rank patch scores, highest first
    scores = patch_scores[0].detach().cpu().numpy()
    flat_scores = scores.reshape(-1)
    n_patches = flat_scores.size
    grid = math.isqrt(n_patches)
    if grid * grid != n_patches:
        raise ValueError(f"Expected one score per patch of a square grid, got {n_patches} scores")
    k = max(0, min(int(k), n_patches))
    # Descending order, then take the first k (slicing `[-k:]` would select
    # every patch when k == 0).
    topk_idx = np.argsort(flat_scores)[::-1][:k]
    coords = [divmod(int(i), grid) for i in topk_idx]   # flat index -> (row, col)

    # 3) Convert the image to uint8 RGB for drawing. Clip first: values outside
    # [0, 1] (e.g. from a normalized input) would wrap around in a uint8 cast.
    img_np = img.detach().cpu().permute(1, 2, 0).float().numpy()
    if img_np.shape[2] == 1:                # grayscale -> 3-channel RGB
        img_np = np.repeat(img_np, 3, axis=2)
    img_draw = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)

    # Patch size follows from image size and grid (16 px for 224 px / 14 patches)
    patch_h = img_draw.shape[0] // grid
    patch_w = img_draw.shape[1] // grid

    # 4) Draw 2-px red outlines that stay inside each selected patch
    thickness = 2
    red = np.array([255, 0, 0], dtype=np.uint8)
    for r, c in coords:
        y1, x1 = r * patch_h, c * patch_w
        y2, x2 = y1 + patch_h, x1 + patch_w   # exclusive ends
        img_draw[y1:y1 + thickness, x1:x2] = red
        img_draw[y2 - thickness:y2, x1:x2] = red
        img_draw[y1:y2, x1:x1 + thickness] = red
        img_draw[y1:y2, x2 - thickness:x2] = red

    # 5) Display result
    if show:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 6))
        plt.imshow(img_draw)
        plt.axis("off")
        plt.title(f"Top {k} Highest-Scoring Patches (Red Boxes)")
        plt.show()

    return img_draw, scores

# NOTE(review): `trained_model` and `val_loader` are not defined in any cell
# visible here, so this call raises NameError on a fresh kernel (see output).
# TODO: build/load the fine-tuned model and validation loader before this call.
_ = visualize_patch_importance(trained_model, val_loader, device, k=1)

NameError: name 'trained_model' is not defined

In [None]:
def visualize_patch_scores(image_tensor, scores):
    """Show an image beside a JET heatmap of its per-patch scores.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before display; 1-channel images are shown as RGB.
        scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened before use.

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2  # not imported at module level anywhere in this notebook

    # Tensor -> uint8 HWC image; clip so out-of-range values saturate instead
    # of wrapping around in the uint8 cast
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # Min-max normalize patch scores so the colormap spans its full range
    s = scores.detach().float().cpu().numpy().reshape(-1)
    s = (s - s.min()) / (s.max() - s.min() + 1e-8)

    # Reshape to the square patch grid, then upscale to image resolution
    # (cv2.resize takes dsize as (width, height))
    grid = math.isqrt(s.size)
    if grid * grid != s.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s.size} scores")
    s = cv2.resize(s.reshape(grid, grid), (width, height))

    # Colorize (OpenCV returns BGR; convert to RGB for matplotlib)
    heatmap = cv2.applyColorMap((s * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Equal-weight blend of image and heatmap
    overlay = (0.5 * img + 0.5 * heatmap).astype(np.uint8)

    # Side-by-side display
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title("Original Image")

    plt.subplot(1, 2, 2)
    plt.imshow(overlay)
    plt.axis("off")
    plt.title("Patch Score Heatmap (Rainbow)")

    plt.show()

# === VISUALIZE PATCH SCORES ON ONE VALIDATION SAMPLE ===
# NOTE(review): `build_dataloaders` and `model` are not defined in any cell
# visible here; they must come from an earlier cell for this to run.

# Rebuild the loaders (expected to match training-time construction); keep val only
_, val_loader = build_dataloaders()
images, labels = next(iter(val_loader))
img = images[0].to(device)

# Single-image forward pass in eval mode without autograd bookkeeping;
# the model returns (logits, patch_scores)
model.eval()
with torch.no_grad():
    logits, patch_scores = model(img.unsqueeze(0))

# Drop the batch axis from the scores and keep a CPU copy of the image.
# NOTE(review): assumes the loader does not normalize, so pixels are in [0, 1].
patch_scores = patch_scores[0]
img_vis = img.cpu()

visualize_patch_scores(img_vis, patch_scores)


In [None]:
import matplotlib.cm as cm

def visualize_patch_scores_clinical(image_tensor, scores):
    """Show an image beside an inferno-colormapped heatmap of its patch scores.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before display; 1-channel images are shown as RGB.
        scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened before use.

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import cv2

    # Image -> uint8 RGB; clip so out-of-range values saturate instead of wrapping
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # Min-max normalize scores to [0, 1]
    s = scores.detach().float().cpu().numpy().reshape(-1)
    s = (s - s.min()) / (s.max() - s.min() + 1e-8)

    # Patch grid -> image resolution with bilinear upsampling
    grid = math.isqrt(s.size)
    if grid * grid != s.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s.size} scores")
    heat = cv2.resize(s.reshape(grid, grid), (width, height), interpolation=cv2.INTER_LINEAR)

    # Perceptually uniform colormap (inferno), blended over the image
    heat_color = (cm.inferno(heat)[..., :3] * 255).astype(np.uint8)
    overlay = (0.6 * img + 0.4 * heat_color).astype(np.uint8)

    # Display
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img)
    ax[0].axis("off")
    ax[0].set_title("Original Image")

    ax[1].imshow(overlay)
    ax[1].axis("off")
    ax[1].set_title("Patch Score Heatmap (Inferno)")

    plt.show()


visualize_patch_scores_clinical(img_vis, patch_scores)

def visualize_patch_scores_clinical(image_tensor, scores):
    """Show an image beside an inferno patch-score heatmap with a raw-score colorbar.

    The heatmap colors are computed from min-max normalized scores; the colorbar
    is labelled in the original (raw) score units so values can be read off.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before display; 1-channel images are shown as RGB.
        scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened before use.

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import matplotlib.cm as cm
    from matplotlib import gridspec

    # --- Convert image to displayable uint8 (clip to avoid uint8 wrap-around) ---
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # --- Normalize patch scores but keep true min/max for colorbar ---
    s = scores.detach().float().cpu().numpy().reshape(-1)
    s_min, s_max = float(s.min()), float(s.max())
    s_norm = (s - s_min) / (s_max - s_min + 1e-8)

    # --- Patch grid -> image-resolution heatmap ---
    grid = math.isqrt(s_norm.size)
    if grid * grid != s_norm.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s_norm.size} scores")
    heat = cv2.resize(s_norm.reshape(grid, grid), (width, height), interpolation=cv2.INTER_LINEAR)

    # --- Colorize ---
    cmap = cm.inferno
    heat_color = (cmap(heat)[..., :3] * 255).astype(np.uint8)

    # --- Blend with original ---
    overlay = (0.6 * img + 0.4 * heat_color).astype(np.uint8)

    # --- Grid layout to prevent squeezing ---
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.05])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    cax = plt.subplot(gs[2])  # dedicated colorbar axis

    # Left: Original
    ax0.imshow(img)
    ax0.axis("off")
    ax0.set_title("Original Image", fontsize=18)

    # Right: Heatmap overlay
    ax1.imshow(overlay)
    ax1.axis("off")
    ax1.set_title("Patch Score Heatmap (Inferno)", fontsize=18)

    # Colorbar in raw score units: normalized 0..1 maps linearly onto s_min..s_max,
    # so the same colormap over [s_min, s_max] labels the heatmap correctly.
    norm = plt.Normalize(vmin=s_min, vmax=s_max)
    sm = cm.ScalarMappable(norm=norm, cmap="inferno")
    sm.set_array([])

    cb = plt.colorbar(sm, cax=cax)
    cb.set_label("Patch score", fontsize=14)

    plt.tight_layout()
    plt.show()


visualize_patch_scores_clinical(img_vis, patch_scores)

def visualize_patch_scores_clinical(image_tensor, scores):
    """Show an image beside a JET patch-score heatmap with a raw-score colorbar.

    The heatmap colors are computed from min-max normalized scores; the colorbar
    is labelled in the original (raw) score units so values can be read off.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before display; 1-channel images are shown as RGB.
        scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened before use.

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import matplotlib.cm as cm
    from matplotlib import gridspec

    # --- Convert image to displayable uint8 (clip to avoid uint8 wrap-around) ---
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # --- Normalize patch scores but keep true min/max for colorbar ---
    s = scores.detach().float().cpu().numpy().reshape(-1)
    s_min, s_max = float(s.min()), float(s.max())
    s_norm = (s - s_min) / (s_max - s_min + 1e-8)

    # --- Patch grid -> image-resolution heatmap ---
    grid = math.isqrt(s_norm.size)
    if grid * grid != s_norm.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s_norm.size} scores")
    heat = cv2.resize(s_norm.reshape(grid, grid), (width, height), interpolation=cv2.INTER_LINEAR)

    # --- Colorize ---
    cmap = cm.jet
    heat_color = (cmap(heat)[..., :3] * 255).astype(np.uint8)

    # --- Blend with original ---
    overlay = (0.6 * img + 0.4 * heat_color).astype(np.uint8)

    # --- Grid layout to prevent squeezing ---
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.05])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    cax = plt.subplot(gs[2])  # dedicated colorbar axis

    # Left: Original
    ax0.imshow(img)
    ax0.axis("off")
    ax0.set_title("Original Image", fontsize=18)

    # Right: Heatmap overlay
    ax1.imshow(overlay)
    ax1.axis("off")
    ax1.set_title("Patch Score Heatmap", fontsize=24)

    # Colorbar in raw score units: normalized 0..1 maps linearly onto s_min..s_max,
    # so the same colormap over [s_min, s_max] labels the heatmap correctly.
    norm = plt.Normalize(vmin=s_min, vmax=s_max)
    sm = cm.ScalarMappable(norm=norm, cmap="jet")
    sm.set_array([])

    cb = plt.colorbar(sm, cax=cax)
    cb.set_label("Patch score", fontsize=14)

    plt.tight_layout()
    plt.show()


visualize_patch_scores_clinical(img_vis, patch_scores)

def visualize_patch_scores_clinical(image_tensor, scores):
    """Show an image beside an inferno patch-score heatmap with a 0-1 colorbar.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before display; 1-channel images are shown as RGB.
        scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened and min-max normalized.

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import matplotlib.cm as cm

    # Convert image to displayable uint8 (clip to avoid uint8 wrap-around)
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # Normalize patch scores to [0, 1]
    s = scores.detach().float().cpu().numpy().reshape(-1)
    s_norm = (s - s.min()) / (s.max() - s.min() + 1e-8)

    # Patch grid -> image-resolution heatmap
    grid = math.isqrt(s_norm.size)
    if grid * grid != s_norm.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s_norm.size} scores")
    heat = cv2.resize(s_norm.reshape(grid, grid), (width, height), interpolation=cv2.INTER_LINEAR)

    # Colorize and blend with the original
    heat_color = (cm.inferno(heat)[..., :3] * 255).astype(np.uint8)
    overlay = (0.6 * img + 0.4 * heat_color).astype(np.uint8)

    # Two equal-width panels
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6))

    # Original
    ax0.imshow(img)
    ax0.axis("off")
    ax0.set_title("Original Image", fontsize=20)

    # Heatmap overlay
    ax1.imshow(overlay)
    ax1.axis("off")
    ax1.set_title("Patch Score Heatmap (Inferno)", fontsize=20)

    # Colorbar on the normalized 0-1 scale used for the colors
    norm = plt.Normalize(vmin=0.0, vmax=1.0)
    sm = cm.ScalarMappable(norm=norm, cmap="inferno")
    sm.set_array([])

    # fraction/pad values keep the colorbar roughly as tall as the heatmap
    cb = fig.colorbar(sm, ax=ax1, fraction=0.046, pad=0.04)
    cb.set_label("Normalized Patch Importance", fontsize=14)

    plt.tight_layout()
    plt.show()


visualize_patch_scores_clinical(img_vis, patch_scores)

def visualize_all_patches(image_tensor, patch_scores):
    """Outline every ViT patch with a 1-px border colored by its normalized score.

    Args:
        image_tensor: (C, H, W) image tensor, presumably in [0, 1]. Values are
            clipped to [0, 1] before drawing; 1-channel images are shown as RGB.
        patch_scores: Tensor with one score per patch of a square patch grid
            (e.g. 196 = 14 x 14); it is flattened and min-max normalized.

    Returns:
        The annotated uint8 RGB image of shape (H, W, 3).

    Raises:
        ValueError: if the number of scores is not a perfect square.
    """
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    import matplotlib.cm as cm

    # Image -> uint8 RGB; clip so out-of-range values saturate instead of wrapping
    img = image_tensor.detach().cpu().permute(1, 2, 0).float().numpy()
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = img.shape[:2]

    # Normalize patch scores and lay them out on the square patch grid
    s = patch_scores.detach().float().cpu().numpy().reshape(-1)
    s_norm = (s - s.min()) / (s.max() - s.min() + 1e-8)
    grid_size = math.isqrt(s_norm.size)
    if grid_size * grid_size != s_norm.size:
        raise ValueError(f"Expected one score per patch of a square grid, got {s_norm.size} scores")
    grid = s_norm.reshape(grid_size, grid_size)

    cmap = cm.inferno
    patch_h = height // grid_size
    patch_w = width // grid_size
    img_draw = img.copy()  # C-contiguous copy, as cv2 drawing requires

    # DRAW BORDERS. cv2.rectangle's second corner is inclusive, so subtract 1 to
    # keep each border inside its own patch instead of overdrawing the neighbour's.
    for r in range(grid_size):
        for c in range(grid_size):
            rgb = (np.asarray(cmap(grid[r, c])[:3]) * 255).astype(np.uint8)
            color = tuple(int(v) for v in rgb)

            y1, x1 = r * patch_h, c * patch_w
            y2, x2 = y1 + patch_h - 1, x1 + patch_w - 1

            cv2.rectangle(img_draw, (x1, y1), (x2, y2), color, 1)

    # PLOT WITH COLORBAR
    fig, ax = plt.subplots(figsize=(7, 7))

    ax.imshow(img_draw)
    ax.axis("off")
    ax.set_title("All Patches (Score-Based Colored Borders)", fontsize=16)

    # Colorbar on the normalized 0-1 scale used for the border colors
    norm = plt.Normalize(vmin=0.0, vmax=1.0)
    sm = cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    # fraction/pad values keep the colorbar roughly as tall as the image
    fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)

    plt.show()

    return img_draw


# Trailing ';' stops Jupyter from echoing the returned (H, W, 3) array.
visualize_all_patches(img_vis, patch_scores);