In [1]:
# Mounting Google Drive and importing libraries
from google.colab import drive
import os
import sys

drive.mount('/content/drive', force_remount=True)

# Adding project src folder to Python path
project_drive_path = "/content/drive/MyDrive/CheXpert_Project"
sys.path.append("/content/drive/MyDrive/CheXpert_Project/src")

Mounted at /content/drive


In [2]:
import torch
import random
import numpy as np
from torch.optim import SGD
from torch.utils.data import DataLoader

from moco.utils import set_seed, get_device
from moco.dataset_loader import CheXpertMoCoDataset
from moco.data_augmentations import moco_medical_transform
from moco.model_builder import MoCoModel
from moco.train_moco import run_moco_training

In [3]:
# Selecting device and setting the seed
seed = 42
set_seed(seed)

device = get_device()
print("Device:", device)

Device: cuda


Note: change this to copy a zip (or look into Google Cloud Storage buckets. Apparently, you can copy a files in a bucket to a colab content folder at 900 MB/s)

In [4]:
!mkdir -p /content/CheXpert_Data
!cp -r /content/drive/MyDrive/CheXpert_Data/train /content/CheXpert_Data/
!cp -r /content/drive/MyDrive/CheXpert_Data/val /content/CheXpert_Data/
!cp -r /content/drive/MyDrive/CheXpert_Data/test /content/CheXpert_Data/
!cp /content/drive/MyDrive/CheXpert_Data/final_project_updated_names_* /content/CheXpert_Data/


In [13]:
image_root = "/content/CheXpert_Data/"
train_csv_path = "/content/CheXpert_Data/final_project_updated_names_train.csv"

print("CheXpert Root:", image_root)
print("Train CSV:", train_csv_path)

train_root = "/content/CheXpert_Data/train"

CheXpert Root: /content/CheXpert_Data/
Train CSV: /content/CheXpert_Data/final_project_updated_names_train.csv


In [14]:
# Loading dataset and augmentations for MoC
train_dataset = CheXpertMoCoDataset(
    csv_path=train_csv_path,
    root_directory=train_root,
    use_augmentations=True
)

# Creating dataset and dataloader
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    drop_last=True
)


len(train_dataset)

14032

In [15]:
embedding_dimension = 128
queue_size = 65536
momentum_update_rate = 0.999
softmax_temperature = 0.07

model = MoCoModel(
    embedding_dim=embedding_dimension,
    queue_size=queue_size,
    momentum=momentum_update_rate,
    temperature=softmax_temperature
).to(device)

optimizer = SGD(
    model.encoder_query.parameters(),
    lr=0.03,
    weight_decay=1e-4,
    momentum=0.9
)

In [9]:
checkpoint_dir = "/content/drive/MyDrive/CHEXPERT_PROJECT/checkpoints"
log_dir = "/content/drive/MyDrive/CHEXPERT_PROJECT/logs"

os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

csv_log_path = f"{log_dir}/moco_training_log.csv"

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import csv
import time

# Writing a function to train the MoCo model for one epoch
def train_one_epoch_moco(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0.0
    loss_fn = nn.CrossEntropyLoss()

    positive_logit_values = []
    negative_logit_values = []


    print("Starting slow batch test...")
    for i, batch in enumerate(data_loader):
        s = time.time()

        images_query, images_key, _ = batch
        images_query = images_query.to(device)
        images_key = images_key.to(device)

        fwd_start = time.time()
        logits, labels = model(images_query, images_key)
        fwd = time.time() - fwd_start

        bwd_start = time.time()
        loss = nn.CrossEntropyLoss()(logits, labels)
        loss.backward()
        bwd = time.time() - bwd_start

        print(f"[{i}] Forward: {fwd:.3f}s | Backward: {bwd:.3f}s")

        break


    for batch_index, batch in enumerate(data_loader):
        t0 = time.time()
        # Dataloader timing
        print(f"[{batch_index}] DataLoader time: {time.time() - t0:.4f}s")
        images_query, images_key, _ = batch
        images_query = images_query.to(device)
        images_key = images_key.to(device)

        t1 = time.time()
        logits, labels = model(images_query, images_key)
        print(f"[{batch_index}] Forward time: {time.time() - t1:.4f}s")

        # Computing the loss using cross entropy
        loss = loss_fn(logits, labels)

        # Performing the backward pass
        t2 = time.time()
        optimizer.zero_grad()
        loss.backward()
        print(f"[{batch_index}] Backward time: {time.time() - t2:.4f}s")

        t3 = time.time()
        optimizer.step()
        print(f"[{batch_index}] Step time: {time.time() - t3:.4f}s")

        total_loss += loss.item()

        positives = logits[:, 0].detach().cpu()
        negatives = logits[:, 1:].mean(dim=1).detach().cpu()

        positive_logit_values.extend(positives.tolist())
        negative_logit_values.extend(negatives.tolist())

    avg_pos = sum(positive_logit_values) / len(positive_logit_values)
    avg_neg = sum(negative_logit_values) / len(negative_logit_values)

    return total_loss / len(data_loader), avg_pos, avg_neg


# Writing a full MoCo training loop with epochs
def run_moco_training(model, data_loader, optimizer, device, epochs, save_path=None, csv_log_path=None):
    loss_history = []

    if csv_log_path is not None:
        with open(csv_log_path, "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([
                "epoch",
                "epoch_loss",
                "learning_rate",
                "queue_mean",
                "queue_std",
                "avg_positive_logit",
                "avg_negative_logit",
                "epoch_seconds"
            ])

    for epoch in range(epochs):
        print(f"START EPOCH {epoch+1}")
        start_time = time.time()

        epoch_loss, avg_pos, avg_neg = train_one_epoch_moco(model, data_loader, optimizer, device)
        loss_history.append(epoch_loss)

        # Getting the learning rate for this epoch
        lr = optimizer.param_groups[0]["lr"]

        # Getting the queue statistics
        queue = model.queue.detach().cpu()
        queue_mean = queue.mean().item()
        queue_std = queue.std().item()

        epoch_seconds = time.time() - start_time

        # Time calculations
        elapsed = time.time() - start_time
        remaining_epochs = epochs - (epoch + 1)
        eta_seconds = remaining_epochs * elapsed

        # ETA formatting
        eta_min = eta_seconds / 60
        eta_hr  = eta_min / 60

        if eta_hr >= 1:
            eta_str = f"{eta_hr:.1f}h"
        else:
            eta_str = f"{eta_min:.1f}m"

        # Printing out Epoch Statistics in Notebook
        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Loss: {epoch_loss:.4f} | "
            f"Pos: {avg_pos:.3f} | "
            f"Neg: {avg_neg:.3f} | "
            f"Queue μ: {queue_mean:.3f} | "
            f"Queue σ: {queue_std:.3f} | "
            f"Time: {elapsed:.1f}s | ETA: {eta_str}"
        )

        if save_path is not None:
            checkpoint_path = f"{save_path}/moco_epoch_{epoch+1}.pth"
            torch.save(model.state_dict(), checkpoint_path)

        if csv_log_path is not None:
            with open(csv_log_path, "a", newline="") as file:
                writer = csv.writer(file)
                writer.writerow([
                    epoch+1,
                    epoch_loss,
                    lr,
                    queue_mean,
                    queue_std,
                    avg_pos,
                    avg_neg,
                    epoch_seconds
                ])

    return loss_history

In [11]:
print(len(train_loader.dataset))

14032


In [None]:
history = run_moco_training(
    model=model,
    data_loader=train_loader,
    optimizer=optimizer,
    device=device,
    epochs=50,
    save_path=checkpoint_dir,
    csv_log_path=csv_log_path
)
print("Training Finished!")

START EPOCH 1
Starting slow batch test...
[0] Forward: 1.281s | Backward: 0.649s
[0] DataLoader time: 0.0000s
[0] Forward time: 0.2093s
[0] Backward time: 0.0149s
[0] Step time: 0.1715s
[1] DataLoader time: 0.0000s
[1] Forward time: 0.1958s
[1] Backward time: 0.0122s
[1] Step time: 0.1687s
[2] DataLoader time: 0.0000s
[2] Forward time: 0.2085s
[2] Backward time: 0.0127s
[2] Step time: 0.0018s
[3] DataLoader time: 0.0000s
[3] Forward time: 0.1958s
[3] Backward time: 0.0122s
[3] Step time: 0.0019s
[4] DataLoader time: 0.0000s
[4] Forward time: 0.2138s
[4] Backward time: 0.0120s
[4] Step time: 0.0018s
[5] DataLoader time: 0.0000s
[5] Forward time: 0.1957s
[5] Backward time: 0.0121s
[5] Step time: 0.0018s
[6] DataLoader time: 0.0000s
[6] Forward time: 0.2139s
[6] Backward time: 0.0144s
[6] Step time: 0.0019s
[7] DataLoader time: 0.0000s
[7] Forward time: 0.1957s
[7] Backward time: 0.0129s
[7] Step time: 0.0018s
[8] DataLoader time: 0.0000s
[8] Forward time: 0.2138s
[8] Backward time: 0.012

In [None]:
import pandas as pd
df = pd.read_csv(csv_log_path)
df.head()