# Install, Import, and Mount Drive

In [None]:
# Mount Google Drive at /content/drive; the report, image and dataset paths used below live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install/upgrade the third-party packages imported below (python-docx provides the `docx` module).
# NOTE(review): versions are not pinned, so a fresh runtime may resolve different releases — consider pinning.
! pip install --upgrade --quiet bitsandbytes datasets peft transformers python-docx

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m37.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.2/515.2 kB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.0/557.0 kB[0m [31m42.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m139.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m43.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import re
import json
import PIL

from typing import Any, Dict, Literal
from collections import defaultdict
from PIL import Image as PILImage
from tqdm import tqdm
from docx import Document


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from transformers import (
    AutoProcessor,
    PaliGemmaForConditionalGeneration, AutoModelForImageTextToText,
    BitsAndBytesConfig, get_constant_schedule_with_warmup
)
from datasets import load_dataset, concatenate_datasets, Image
from peft import get_peft_model, LoraConfig, TaskType
import wandb


# Preprocess Data and Save to Drive
Skip this section if the next section (Load Data from Drive) runs without issues.

In [None]:
# Parse a text report into a flat dict of named sections (easier to format later)
def parse_doc_into_dict(document_obj):
  """Pull the patient number, ACR category and per-breast sections out of a report.

  Args:
    document_obj: object exposing ``paragraphs``; each paragraph exposes ``text``.
      The paragraph texts are joined with newlines before matching.

  Returns:
    Dict with the keys ``patient_id``, ``acr_category``,
    ``right_breast_first_finding``, ``left_breast_first_finding``,
    ``right_breast_opinion``, ``left_breast_opinion``,
    ``right_breast_second_finding`` and ``left_breast_second_finding``.
    Each value is the stripped first capture group of the first match,
    or ``""`` when the pattern is not found.
  """
  document_string = '\n'.join(paragraph.text for paragraph in document_obj.paragraphs)

  # (output key, pattern, regex flags) -- the order fixes the key order of the result.
  section_patterns = (
      ('patient_id', r'PATIENT NO\.(.*)', 0),
      ('acr_category', r"(ACR.*)", 0),
      ('right_breast_first_finding', r"Right Breast:(.*?)\n\n", re.DOTALL),
      ('left_breast_first_finding', r"Left Breast:(.*?)\n\n", re.DOTALL),
      ('right_breast_opinion', r"OPINION:.*?Right Breast:(.*?)\n\n", re.DOTALL),
      ('left_breast_opinion', r"OPINION:.*?Left Breast:(.*?)\n\n", re.DOTALL),
      ('right_breast_second_finding', r"CONTRAST ENHANCED SPECTRAL.*?Right Breast:\n{1,2}(.*?)(\n\n|$)", re.DOTALL),
      ('left_breast_second_finding', r"CONTRAST ENHANCED SPECTRAL.*?Left Breast:\n{1,2}(.*?)(\n\n|$)", re.DOTALL),
  )

  parsed = {}
  for key, pattern, flags in section_patterns:
    found = re.search(pattern, document_string, flags=flags)
    parsed[key] = found.group(1).strip() if found else ""
  return parsed

# Parse every .docx report in the report folder into a dict keyed by patient id
report_root_dir = "/content/drive/MyDrive/Dataset/Medical-reports-for-cases-/Medical reports for cases"
files = [os.path.join(report_root_dir, f) for f in os.listdir(report_root_dir) if f.lower().endswith('.docx')]


# parsed_reports: Dict[patient_id, patient_info]
# patient_info: Dict[data_key, data_value] (the fields returned by parse_doc_into_dict, minus 'patient_id')
parsed_reports: Dict[str, Dict[str, str]] = {}

for file in tqdm(files):
    document_obj = Document(file)
    document_dict = parse_doc_into_dict(document_obj)
    # The patient id becomes the lookup key, so it is removed from the per-patient fields.
    # NOTE(review): reports where no "PATIENT NO." line is found all map to the key "" and
    # overwrite each other, as do two reports sharing the same number — confirm this cannot happen.
    patient_id = document_dict.pop('patient_id')
    parsed_reports[patient_id] = document_dict

# Example: parsed fields for patient '176'
parsed_reports.get('176')

100%|██████████| 326/326 [00:26<00:00, 12.29it/s]


{'acr_category': 'ACR C: Heterogenously dense breasts.',
 'right_breast_first_finding': 'Regional microcalcifications with clusters and edema pattern are noted.',
 'left_breast_first_finding': 'No speculated mass lesions or suspicious microcalcifications.\nNormal skin thickness and contour of breast.',
 'right_breast_opinion': 'Regional suspicious looking microcalcifications associated with edema pattern (BIRADS 4).',
 'left_breast_opinion': 'Normal breast examination (BIRADS 1).',
 'right_breast_second_finding': 'Upper outer faint linear non mass enhancement (BIRADS 4).',
 'left_breast_second_finding': 'No mass or non mass enhancement (BIRADS 1).'}

In [None]:
# Keep only the part of a parsed report that describes one image and format it as plain text
def extract_and_generate_report(left_or_right: Literal["L","R"],
                                mammography_type: Literal["CM", "DM"],
                                report_dict: dict[str, str]) -> str:
    """Build the report text for one breast side and one image type.

    Args:
        left_or_right: "L" or "R", selecting the left or right breast sections.
        mammography_type: "DM" selects the first finding plus the opinion and
            prepends the ACR category; "CM" selects the second finding only.
        report_dict: parsed report fields, indexed by the section keys below.

    Returns:
        The ACR line (DM only) and the findings text separated by a blank line,
        with surrounding whitespace stripped. An unrecognised side/type pair
        contributes no findings text.
    """
    # Report sections quoted for each (side, image type) pair: findings first, opinion second.
    section_keys = {
        ("R", "DM"): ("right_breast_first_finding", "right_breast_opinion"),
        ("R", "CM"): ("right_breast_second_finding",),
        ("L", "DM"): ("left_breast_first_finding", "left_breast_opinion"),
        ("L", "CM"): ("left_breast_second_finding",),
    }.get((left_or_right, mammography_type), ())

    # Only include ACR information for Low energy image
    header = f"{report_dict['acr_category']}" if mammography_type == 'DM' else ""

    body = ""
    if section_keys:
        body = f"Findings: \n{report_dict[section_keys[0]]}"
        if len(section_keys) > 1:
            body += f"\n\nOpinion: \n{report_dict[section_keys[1]]}"

    return f"{header}\n\n{body}".strip()

# Example: contrast-enhanced (CM) report text for the left breast of patient '176'
print(extract_and_generate_report("L", "CM", parsed_reports.get('176')))

Findings: 
No mass or non mass enhancement (BIRADS 1).


In [None]:
images_root_dir = '/content/drive/MyDrive/CDD-CESM-curated-dataset/images'

# Group images by patient_id.
# Every patient record starts with all image/report slots set to None, so slots for
# images that do not exist on disk simply stay None.
patient_data = defaultdict(lambda: {
    "patient_id": None,
    "image_l_cc_dm": None,
    "image_l_cc_cm": None,
    "image_l_mlo_dm": None,
    "image_l_mlo_cm": None,
    "image_r_cc_dm": None,
    "image_r_cc_cm": None,
    "image_r_mlo_dm": None,
    "image_r_mlo_cm": None,
    "report_l_dm": None,
    "report_r_dm": None,
    "report_l_cm": None,
    "report_r_cm": None,
})

# Image files that could not be attached to a patient record
# (file name not in the expected shape, or no parsed report for that patient).
skipped_image_files = []

for image_file_name in tqdm(os.listdir(images_root_dir)):
    if not image_file_name.lower().endswith('.jpg'):
        continue

    # Extract metadata from file names shaped like "P176_L_DM_CC.jpg"
    patient_match = re.search(r"P(.*?)_", image_file_name)
    side_match = re.search(r".*?_([LR])_", image_file_name)
    view_match = re.search(r".*?_(CC|MLO)", image_file_name)
    type_match = re.search(r".*?_(CM|DM)_", image_file_name)
    if not (patient_match and side_match and view_match and type_match):
        # Unexpected file name: skip it instead of failing on `.group(1)` of None
        skipped_image_files.append(image_file_name)
        continue

    patient_id = patient_match.group(1)
    left_or_right = side_match.group(1)
    cc_or_mlo = view_match.group(1)
    mammography_type = type_match.group(1)

    # Path data & report data
    report_dict = parsed_reports.get(patient_id)
    if report_dict is None:
        # No parsed report for this patient: there is nothing to pair the image with
        skipped_image_files.append(image_file_name)
        continue

    # Build keys
    image_key = f"image_{left_or_right}_{cc_or_mlo}_{mammography_type}".lower()
    report_key = f"report_{left_or_right}_{mammography_type}".lower()

    image_full_path = os.path.join(images_root_dir, image_file_name)
    report = extract_and_generate_report(left_or_right, mammography_type, report_dict)

    # Store in patient data
    patient_data[patient_id]["patient_id"] = patient_id
    patient_data[patient_id][image_key] = image_full_path
    patient_data[patient_id][report_key] = report

if skipped_image_files:
    print(f"Skipped {len(skipped_image_files)} image file(s) with an unexpected name or no parsed report, e.g. {skipped_image_files[:5]}")

# Example
patient_data.get('176')

100%|██████████| 2007/2007 [00:00<00:00, 12034.08it/s]


{'patient_id': '176',
 'image_l_cc_dm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_L_DM_CC.jpg',
 'image_l_cc_cm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_L_CM_CC.jpg',
 'image_l_mlo_dm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_L_DM_MLO.jpg',
 'image_l_mlo_cm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_L_CM_MLO.jpg',
 'image_r_cc_dm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_R_DM_CC.jpg',
 'image_r_cc_cm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_R_CM_CC.jpg',
 'image_r_mlo_dm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_R_DM_MLO.jpg',
 'image_r_mlo_cm': '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/P176_R_CM_MLO.jpg',
 'report_l_dm': 'ACR C: Heterogenously dense breasts.\n\nFindings: \nNo speculated mass lesions or suspicious microcalcifications.\nNormal skin thickness and contour of breast.\n\nOpinion: \nNormal breast examination (BIRADS 1

In [None]:
# Flatten the per-patient records into a list and write them as JSON Lines (one JSON object per line)
datasets = list(patient_data.values())
output_file = "medgemma_contrastive_dataset.jsonl"

with open(output_file, 'w', encoding='utf-8') as f:
  for json_line in datasets:
    f.write(json.dumps(json_line) + '\n')

# Copy the file to Drive so the preprocessing section can be skipped in later sessions.
# NOTE(review): the file is written relative to the working directory but copied from /content,
# so this assumes the notebook runs with /content as its working directory.
!cp /content/medgemma_contrastive_dataset.jsonl /content/drive/MyDrive/medgemma_contrastive_dataset.jsonl

# Load Data from Drive

In [None]:
train_size = 0.9  # @param {type: "number"}
validation_size = 0.1  # @param {type: "number"}

# Columns that store image file paths; each one is cast to the `Image` feature type in turn.
image_columns = (
    "image_l_cc_dm",
    "image_l_cc_cm",
    "image_l_mlo_dm",
    "image_l_mlo_cm",
    "image_r_cc_dm",
    "image_r_cc_cm",
    "image_r_mlo_dm",
    "image_r_mlo_cm",
)

dataset = load_dataset("json", data_files="/content/drive/MyDrive/medgemma_contrastive_dataset.jsonl", split="train")
for image_column in image_columns:
    dataset = dataset.cast_column(image_column, Image())


Generating train split: 0 examples [00:00, ? examples/s]

## Process to jsonl file



This step prepares the data for the custom training loop used in the contrastive learning step.   
Skip the next cell and use `dataset` directly if you are training with the Hugging Face Trainer loop instead.


In [None]:
train_size = 0.9  # @param {type: "number"}
validation_size = 0.1  # @param {type: "number"}

dataset = load_dataset("json", data_files="/content/drive/MyDrive/medgemma_contrastive_dataset.jsonl", split="train")


def _per_breast_view(source, cc_column, mlo_column, report_column):
    """Keep one breast's CC path, MLO path and report, renamed to cc_path / mlo_path / report."""
    view = source.select_columns([cc_column, mlo_column, report_column])
    view = view.rename_column(cc_column, "cc_path")
    view = view.rename_column(mlo_column, "mlo_path")
    return view.rename_column(report_column, "report")


def _has_both_views_and_report(row):
    """Truthy only when the CC path, the MLO path and the report are all present and non-empty."""
    return row['cc_path'] and row['mlo_path'] and row['report']


# One row per breast: the left and right low-energy (DM) views are stacked into the same three columns
dataset_l = _per_breast_view(dataset, "image_l_cc_dm", "image_l_mlo_dm", "report_l_dm")
dataset_r = _per_breast_view(dataset, "image_r_cc_dm", "image_r_mlo_dm", "report_r_dm")

combined_dataset = concatenate_datasets([dataset_l, dataset_r])
filtered_dataset = combined_dataset.filter(_has_both_views_and_report)

data = filtered_dataset.train_test_split(
    train_size=train_size,
    test_size=validation_size,
    shuffle=True,
    seed=42,
)

# The held-out split is used for validation, so store it under that name
data["validation"] = data.pop("test")
print(f"Train split size: {len(data['train'])}\nValidation split size: {len(data['validation'])}")
data

Filter:   0%|          | 0/652 [00:00<?, ? examples/s]

Train split size: 391
Validation split size: 44


DatasetDict({
    train: Dataset({
        features: ['cc_path', 'mlo_path', 'report'],
        num_rows: 391
    })
    validation: Dataset({
        features: ['cc_path', 'mlo_path', 'report'],
        num_rows: 44
    })
})

In [None]:
# Write each split as JSON Lines; these file names are the ones Config.train_file / Config.validation_file point to.
data["train"].to_json("train_dataset.jsonl", orient="records", lines=True)
data["validation"].to_json("validation_dataset.jsonl", orient="records", lines=True)

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

20152

# Training setup

## Config

In [None]:
class Config:
    """Paths and hyperparameters for the contrastive training run.

    Plain class attributes, read as ``Config.<name>``; the public ones are also
    logged as the W&B run config.
    """
    model_id = "google/medgemma-4b-it"
    train_file = "train_dataset.jsonl"
    validation_file = "validation_dataset.jsonl"
    output_dir = "./mammo_clip_checkpoints"

    # Hyperparameters
    max_length = 128          # Report text is padded/truncated to this many tokens
    batch_size = 32           # Keep small per GPU
    grad_accumulation = 1     # Only used to size the LR schedule; the training loop steps the optimizer on every batch
    num_epochs = 50           # Recommend 25~50 Maybe stop at 10 to prevent overfitting
    learning_rate = 2e-4      # LoRA needs higher learning rate (other papers use 5e-5 without LoRA)
    weight_decay = 0.05 #5e-3       # Smaller dataset needs higher weight decay to prevent overfit (Other paper uses 1e-4, some highly similar paper use 0.05)
    temperature = 0.07        # InfoNCE temperature. NOTE(review): not read by the visible code; the model hard-codes its initial logit scale as ln(1/0.07)
    projection_dim = 512      # Dimension to project images/text into. NOTE(review): not passed to the model in the visible code, which relies on its own default of 512
    val_split_pct = 0.1       # 10% for validation. NOTE(review): not read by the visible code; the split is made when the JSONL files are written
    num_workers = 2           # DataLoader worker processes
    lora_r = 8                # LoRA r parameter
    lora_alpha = 16           # LoRA alpha parameter
    lora_dropout = 0.40       # LoRA dropout parameter

## Dataset

In [None]:
class MammoDataset(Dataset):
    """Paired CC/MLO mammogram views plus their report text, read from a JSON Lines file.

    Each record must provide ``cc_path``, ``mlo_path`` and ``report``. Items are
    returned as a dict of tensors: ``pixel_values_cc``, ``pixel_values_mlo``,
    ``input_ids`` and ``attention_mask``.

    Args:
        jsonl_file: path to a JSON Lines file with one record per line.
        processor: object exposing ``image_processor`` and ``tokenizer``.
        split: augmentation is applied only when this is ``"train"``.
    """

    def __init__(self, jsonl_file, processor, split="test"): # Default is test split to bypass data augmentation pipeline
        self.data = self._read_jsonl(jsonl_file)
        self.processor = processor

        # Data augmentation pipeline for training
        self.split = split
        self.transform = T.Compose([
                # 1. Random Resized Crop:
                T.RandomResizedCrop(size=(896, 896), scale=(0.7, 0.9), ratio=(0.45, 0.65)),

                # 2. Intensity/Contrast Jitter
                T.ColorJitter(brightness=0.2, contrast=0.2),

                # 3. Random Gaussian Blur
                T.RandomApply([T.GaussianBlur(kernel_size=19, sigma=(0.1, 2.0))], p=0.5)
            ])

    @staticmethod
    def _read_jsonl(jsonl_file):
        """Read one JSON object per non-blank line of ``jsonl_file`` (UTF-8)."""
        records = []
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads
                if line.strip():
                    records.append(json.loads(line))
        return records

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file."""
        # convert() returns a fully loaded copy, so the file handle can be released right away
        with PILImage.open(path) as image:
            return image.convert("RGB")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load Images
        image_cc = self._load_rgb(item['cc_path'])
        image_mlo = self._load_rgb(item['mlo_path'])

        # Each view gets its own independently sampled augmentation
        if self.split == "train":
            image_cc = self.transform(image_cc)
            image_mlo = self.transform(image_mlo)

        text = item['report']

        # Process Inputs using the VLM processor
        # We process CC and MLO separately but using the same pipeline
        inputs_cc = self.processor.image_processor(images=image_cc, return_tensors="pt")
        inputs_mlo = self.processor.image_processor(images=image_mlo, return_tensors="pt")
        report = self.processor.tokenizer(text, return_tensors="pt", padding="max_length", max_length=Config.max_length, truncation=True)

        # The processors return a leading batch dimension of 1; drop it so the DataLoader can stack items
        return {
            "pixel_values_cc": inputs_cc["pixel_values"].squeeze(0),
            "pixel_values_mlo": inputs_mlo["pixel_values"].squeeze(0),
            "input_ids": report["input_ids"].squeeze(0), # Text is same for both
            "attention_mask": report["attention_mask"].squeeze(0)
        }

## Model Architecture

In [None]:
class MammoContrastiveModel(nn.Module):
    """Contrastive wrapper that projects image and text features into one shared space.

    Args:
        base_model: model exposing ``vision_tower`` and ``language_model`` callables.
        hidden_size: width of the language model's hidden states.
        projection_dim: size of the shared embedding space.
        vision_hidden_size: width of the vision tower's output features
            (defaults to 1152, the value previously hard-coded here).
    """

    def __init__(self, base_model, hidden_size, projection_dim=512, vision_hidden_size=1152):
        super().__init__()
        self.base_model = base_model

        # Contrastive Projection Heads
        # We project the large VLM embeddings down to a smaller "CLIP space"
        self.visual_projection = nn.Linear(vision_hidden_size, projection_dim).to(dtype=torch.bfloat16) # Vision tower embedding size
        self.text_projection = nn.Linear(hidden_size, projection_dim).to(dtype=torch.bfloat16) # LLM hidden size
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6592) # ln(1/0.07)

    def get_image_features(self, pixel_values):
        """Return projected image embeddings: mean over dim 1 of the vision tower output, then the visual head."""
        # Note: Accessing .vision_tower depends on specific HF implementation of MedGemma/PaliGemma
        vision_outputs = self.base_model.vision_tower(pixel_values)

        # Pool by averaging over dim 1 (assumed to be the patch axis of a (batch, patches, dim) output)
        pooled_vision = vision_outputs.last_hidden_state.mean(dim=1)

        # Cast to the projection layer's dtype
        return self.visual_projection(pooled_vision.to(self.visual_projection.weight.dtype))

    def get_text_features(self, input_ids, attention_mask):
        """Return projected text embeddings taken from the last attended token of each sequence."""
        outputs = self.base_model.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = outputs.hidden_states[-1]

        # Index of the last attended token; this is only the last real token when padding is on the right
        sequence_lengths = (attention_mask.sum(dim=1) - 1).to(last_hidden_state.device)
        batch_size = last_hidden_state.shape[0]
        pooled_text = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]

        # Cast to the projection layer's dtype
        return self.text_projection(pooled_text.to(self.text_projection.weight.dtype))

    def forward(self, pixel_values_cc, pixel_values_mlo, input_ids, attention_mask):
        """Return L2-normalised embeddings ``(z_cc, z_mlo, z_text)`` for the two views and the report."""
        # 1. Get Embeddings
        z_cc = self.get_image_features(pixel_values_cc)
        z_mlo = self.get_image_features(pixel_values_mlo)
        z_text = self.get_text_features(input_ids, attention_mask)

        # 2. Normalize
        z_cc = F.normalize(z_cc, p=2, dim=1)
        z_mlo = F.normalize(z_mlo, p=2, dim=1)
        z_text = F.normalize(z_text, p=2, dim=1)

        return z_cc, z_mlo, z_text

## Proposed Contrastive Loss Function

In [None]:
def contrastive_loss(z1, z2, logit_scale):
    """
    Symmetric cross-entropy (CLIP-style) loss between two row-aligned embedding batches.

    Row i of ``z1`` is treated as the match of row i of ``z2``. The pairwise dot
    products (cosine similarities when both inputs are L2-normalised) are scaled
    by ``exp(logit_scale)`` and scored in both directions.

    Returns:
        (loss, summed_similarity): the mean of the two directional cross-entropy
        losses, and the sum of the matched-pair (diagonal) similarities, which is
        intended for reporting only.
    """
    similarity = z1 @ z2.t()
    scaled_logits = similarity * logit_scale.exp()

    # The i-th row of z1 should rank the i-th row of z2 first, and vice versa
    targets = torch.arange(z1.size(0), device=z1.device)

    loss_z1_to_z2 = F.cross_entropy(scaled_logits, targets)
    loss_z2_to_z1 = F.cross_entropy(scaled_logits.t(), targets)

    matched_similarity_sum = similarity.diag().sum()
    return (loss_z1_to_z2 + loss_z2_to_z1) / 2, matched_similarity_sum

## Model Initialization

In [None]:
# A. Load Model in 4-bit (QLoRA)
# NOTE(review): loading prints "`torch_dtype` is deprecated! Use `dtype` instead!" with the installed
# transformers release; the key is also read back below for the quantization config, so rename both together.
model_kwargs = dict(
    attn_implementation="sdpa", # "sdpa", "eager"
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# NF4 double quantization; compute and storage dtypes follow the model dtype chosen above
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

gemma_model = AutoModelForImageTextToText.from_pretrained(Config.model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(Config.model_id)

# Use right padding: the text pooling in MammoContrastiveModel picks position
# attention_mask.sum() - 1, which is the last real token only when padding is on the right
processor.tokenizer.padding_side = "right"


# Disable cache for gradient checkpointing
if hasattr(gemma_model.config, "use_cache"):
    gemma_model.config.use_cache = False

# Enable Gradient Checkpointing (save memory and allow larger batch size)
gemma_model.gradient_checkpointing_enable()

# Fix for "None of the inputs have requires_grad=True" warning
# This ensures gradients flow through, so checkpointing works
gemma_model.enable_input_require_grads()

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [None]:
# B. Apply LoRA (Crucial for training 4B model on limited hardware)
# Adapters are attached through target_modules="all-linear".
# NOTE(review): the intent is to adapt both the vision and the language modules — confirm against
# the trainable-parameter summary printed below that "all-linear" covers both.
peft_config = LoraConfig(
    r=Config.lora_r,
    lora_alpha=Config.lora_alpha,
    lora_dropout=Config.lora_dropout,
    bias="none",
    task_type="FEATURE_EXTRACTION", # Custom task effectively

    #target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    target_modules="all-linear",
    # modules_to_save=[
    #     "lm_head",
    #     "embed_tokens",
    # ],

)
base_model = get_peft_model(gemma_model, peft_config)
base_model.print_trainable_parameters()

# C. Wrap in our Contrastive Module
# No dtype cast happens here: the base model keeps the dtype it was loaded with, and the
# projection heads are created in bfloat16 inside MammoContrastiveModel.
# hidden_size comes from the text config; projection_dim is left at the wrapper's default (512).
model = MammoContrastiveModel(base_model, hidden_size=base_model.config.text_config.hidden_size).cuda()

trainable params: 19,248,896 || all params: 4,319,328,368 || trainable%: 0.4456


## Optimizer and DataLoader

In [None]:
# D. Optimizer
# We only train the LoRA adapters, the new Projection Heads and the logit scale.
# NOTE(review): model.parameters() also yields the frozen base weights; presumably they never
# receive gradients and are therefore left untouched by AdamW — confirm.
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)

# E. Data Loader
# NOTE: `dataset` is rebound here from the Hugging Face dataset loaded earlier to the training
# MammoDataset; train() reads this name for the training-set size.
dataset = MammoDataset(Config.train_file, processor, split="train")
dataloader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=True, num_workers=Config.num_workers)

# Validation uses split="test", so no augmentation is applied, and keeps a fixed order
validation_dataset = MammoDataset(Config.validation_file, processor, split="test")
validation_dataloader = DataLoader(validation_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=Config.num_workers)

# F. Scheduler
# Linear warmup over the first 10% of the planned optimizer steps, then a constant learning rate
num_update_steps_per_epoch = len(dataloader) // Config.grad_accumulation
max_train_steps = Config.num_epochs * num_update_steps_per_epoch

lr_scheduler = get_constant_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * max_train_steps), # 10% Warmup
    # num_training_steps=max_train_steps
)

## Validation function

In [None]:
# Recall at K (R@k) metrics
def compute_recall_k(z_img, z_text, k_vals=(1, 2, 3)):
    """
    Compute retrieval recall R@k for row-aligned embedding batches.

    Row i of ``z_img`` is expected to retrieve row i of ``z_text``. Similarities
    are plain dot products, i.e. cosine similarities when both inputs are
    L2-normalised.

    Args:
        z_img: query embeddings of shape (N, D).
        z_text: candidate embeddings of shape (M, D).
        k_vals: iterable of cut-offs (defaults to 1, 2 and 3). A cut-off larger
            than the number of candidates is evaluated against all candidates
            instead of raising.

    Returns:
        Dict mapping ``"R@{k}"`` to the percentage (0-100) of queries whose
        matching candidate is among the top k. Empty when ``k_vals`` is empty.
    """
    k_vals = list(k_vals)
    if not k_vals:
        return {}

    # 1. Similarity matrix, shape [N, M]
    logits = torch.matmul(z_img, z_text.t())

    # 2. Ground truth: the i-th image matches the i-th text (diagonal)
    num_queries, num_candidates = logits.shape
    labels = torch.arange(num_queries, device=logits.device)

    # 3. Find top K matches; topk raises if k exceeds the number of candidates, so clamp it
    max_k = min(max(k_vals), num_candidates)
    _, top_indices = logits.topk(max_k, dim=1)

    results = {}
    for k in k_vals:
        # Check if the correct label is within the top k predictions
        is_correct = top_indices[:, :k].eq(labels.unsqueeze(1)).any(dim=1)
        results[f"R@{k}"] = is_correct.float().mean().item() * 100

    return results

## WandB Logger Initialization

In [None]:
# Read the W&B API key from Colab's secret store rather than hard-coding it in the notebook
from google.colab import userdata
os.environ["WANDB_API_KEY"] = userdata.get('WANDB_API_KEY')

# Log every public attribute of Config (names not starting with "_") as the run configuration
config_dict = {k: v for k, v in vars(Config).items() if not k.startswith('_')}
config_dict

# Open the run; train() logs to it through `logger` and closes it when training ends
logger = wandb.init(
    entity="wjlingz-none",
    project="mammogram-contrastive-learning-project",
    name="contrastive-training",
    config=config_dict
)

# Use `logger.log({"acc": 0.1, "loss": 0.2})` to log

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mwjlingz[0m ([33mwjlingz-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Training Loop

In [None]:
def train():
    """Run contrastive training with per-epoch validation, W&B logging, checkpointing and early stopping.

    Uses the notebook-level objects ``model``, ``optimizer``, ``lr_scheduler``,
    ``dataloader``, ``dataset``, ``validation_dataloader``, ``validation_dataset``,
    ``logger`` and ``Config``, plus ``contrastive_loss`` and ``compute_recall_k``.

    Each epoch:
      1. trains on every batch, optimising the mean of the CC-text and MLO-text losses;
      2. evaluates on the validation loader;
      3. logs losses, mean matched-pair similarities and R@k metrics to W&B;
      4. saves a checkpoint (adapter weights, both projection heads and the logit
         scale) whenever the validation loss improves, and stops once it has not
         improved for 3 consecutive epochs.

    The W&B run is always closed on exit, including when training raises.
    """

    print("Starting Contrastive Pre-Training...")

    best_eval_loss = float('inf') # Used for early stopping, use high loss as a starting point
    best_eval_loss_epoch = 0 # 1-based epoch of the best evaluation loss seen so far

    try:
        for epoch in range(Config.num_epochs):

            optimizer.zero_grad()
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
            num_steps = len(dataloader)
            training_data_size = len(dataset)
            validation_data_size = len(validation_dataset)

            # ========== Training loop ==========
            model.train()     # Switch to training mode
            epoch_loss = 0      # For training performance evaluation
            all_training_z_cc = []
            all_training_z_mlo = []
            all_training_z_text = []
            cc_text_total_cos_sim = 0
            mlo_text_total_cos_sim = 0
            cc_mlo_total_cos_sim = 0

            for step, batch in enumerate(progress_bar):
                # Move to GPU
                # Use bfloat16 to match the model weights
                pv_cc = batch["pixel_values_cc"].cuda().to(torch.bfloat16).requires_grad_(True)
                pv_mlo = batch["pixel_values_mlo"].cuda().to(torch.bfloat16).requires_grad_(True)
                ids = batch["input_ids"].cuda()
                mask = batch["attention_mask"].cuda()

                # Forward
                z_cc, z_mlo, z_text = model(pv_cc, pv_mlo, ids, mask)

                # Calculate Losses
                loss_cc_text, cc_text_summed_cos_sim = contrastive_loss(z_cc, z_text, model.logit_scale)
                loss_mlo_text, mlo_text_summed_cos_sim = contrastive_loss(z_mlo, z_text, model.logit_scale)
                loss_img_img, cc_mlo_summed_cos_sim = contrastive_loss(z_cc, z_mlo, model.logit_scale)

                # Keep the embeddings for the epoch-level recall metrics.
                # detach() so the stored copies do not keep each batch's autograd graph alive.
                all_training_z_cc.append(z_cc.detach().float().cpu())
                all_training_z_mlo.append(z_mlo.detach().float().cpu())
                all_training_z_text.append(z_text.detach().float().cpu())
                cc_text_total_cos_sim += cc_text_summed_cos_sim.item()
                mlo_text_total_cos_sim += mlo_text_summed_cos_sim.item()
                cc_mlo_total_cos_sim += cc_mlo_summed_cos_sim.item()

                # Weighted Sum (Can tune these weights)
                # total_loss = (loss_cc_text + loss_mlo_text + loss_img_img) / 3 # Original MVS-CLIP loss setup
                total_loss = (loss_cc_text + loss_mlo_text) / 2 # Removed MVS loss for ablation study

                # Back propagation
                total_loss.backward()

                # Parameter update
                optimizer.step()
                lr_scheduler.step() # Update learning rate
                optimizer.zero_grad()

                epoch_loss += total_loss.item()
                progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})

            # Compute performance metrics for logging
            # (summed matched-pair similarities divided by the number of samples = mean similarity)
            cc_text_cos_sim = cc_text_total_cos_sim / training_data_size
            mlo_text_cos_sim = mlo_text_total_cos_sim / training_data_size
            cc_mlo_cos_sim = cc_mlo_total_cos_sim / training_data_size

            z_cc_training = torch.cat(all_training_z_cc, dim=0)
            z_mlo_training = torch.cat(all_training_z_mlo, dim=0)
            z_text_training = torch.cat(all_training_z_text, dim=0)

            cc_text_training_metrics = compute_recall_k(z_cc_training, z_text_training)
            mlo_text_training_metrics = compute_recall_k(z_mlo_training, z_text_training)
            cc_mlo_training_metrics = compute_recall_k(z_cc_training, z_mlo_training)
            # ========== Training loop ==========


            # ========== Evaluation loop ==========
            # Evaluation in terms of loss, similarity metrics, R@1,2,3 performance
            model.eval()    # Switch to evaluation mode

            evaluation_loss = 0
            all_evaluation_z_cc = []
            all_evaluation_z_mlo = []
            all_evaluation_z_text = []
            evaluation_cc_text_total_cos_sim = 0
            evaluation_mlo_text_total_cos_sim = 0
            evaluation_cc_mlo_total_cos_sim = 0

            with torch.no_grad():
                for batch in validation_dataloader:
                    # Move inputs to GPU
                    pv_cc = batch["pixel_values_cc"].cuda().to(torch.bfloat16)
                    pv_mlo = batch["pixel_values_mlo"].cuda().to(torch.bfloat16)
                    ids = batch["input_ids"].cuda()
                    mask = batch["attention_mask"].cuda()

                    # Forward
                    z_cc, z_mlo, z_text = model(pv_cc, pv_mlo, ids, mask)

                    # Calculate Losses
                    loss_cc_text, cc_text_summed_cos_sim = contrastive_loss(z_cc, z_text, model.logit_scale)
                    loss_mlo_text, mlo_text_summed_cos_sim = contrastive_loss(z_mlo, z_text, model.logit_scale)
                    loss_img_img, cc_mlo_summed_cos_sim = contrastive_loss(z_cc, z_mlo, model.logit_scale)
                    # NOTE(review): the evaluation loss still includes the image-image term that the
                    # training objective above drops, so the two logged losses are not directly
                    # comparable and model selection uses a different objective — confirm this is intended.
                    total_loss = (loss_cc_text + loss_mlo_text + loss_img_img) / 3

                    # Recording the output to avoid recomputation for evaluation purpose
                    all_evaluation_z_cc.append(z_cc.float().cpu())
                    all_evaluation_z_mlo.append(z_mlo.float().cpu())
                    all_evaluation_z_text.append(z_text.float().cpu())
                    evaluation_cc_text_total_cos_sim += cc_text_summed_cos_sim.item()
                    evaluation_mlo_text_total_cos_sim += mlo_text_summed_cos_sim.item()
                    evaluation_cc_mlo_total_cos_sim += cc_mlo_summed_cos_sim.item()

                    evaluation_loss += total_loss.item()

            # Compute performance metrics for logging
            evaluation_cc_text_cos_sim = evaluation_cc_text_total_cos_sim / validation_data_size
            evaluation_mlo_text_cos_sim = evaluation_mlo_text_total_cos_sim / validation_data_size
            evaluation_cc_mlo_cos_sim = evaluation_cc_mlo_total_cos_sim / validation_data_size

            z_cc_evaluation = torch.cat(all_evaluation_z_cc, dim=0)
            z_mlo_evaluation = torch.cat(all_evaluation_z_mlo, dim=0)
            z_text_evaluation = torch.cat(all_evaluation_z_text, dim=0)

            cc_text_evaluation_metrics = compute_recall_k(z_cc_evaluation, z_text_evaluation)
            mlo_text_evaluation_metrics = compute_recall_k(z_mlo_evaluation, z_text_evaluation)
            cc_mlo_evaluation_metrics = compute_recall_k(z_cc_evaluation, z_mlo_evaluation)

            # Mean of the per-batch losses
            avg_eval_loss = evaluation_loss / len(validation_dataloader)
            # ========== Evaluation loop ==========


            # ========== Logging Step ==========
            logger.log({"training/epoch_time_taken": progress_bar.format_dict['elapsed'],
                        "training/logit_scale": model.logit_scale.exp().item(),
                        "training/epoch_loss": epoch_loss / num_steps,
                        "training/cc_text_sim":cc_text_cos_sim,
                        "training/mlo_text_sim":mlo_text_cos_sim,
                        "training/cc_mlo_sim":cc_mlo_cos_sim,
                        **{f"training/cc-text-{k}": v for k, v in cc_text_training_metrics.items()},
                        **{f"training/mlo-text-{k}": v for k, v in mlo_text_training_metrics.items()},
                        **{f"training/cc-mlo-{k}": v for k, v in cc_mlo_training_metrics.items()},
                        "evaluation/epoch_loss": avg_eval_loss,
                        "evaluation/cc_text_sim": evaluation_cc_text_cos_sim,
                        "evaluation/mlo_text_sim": evaluation_mlo_text_cos_sim,
                        "evaluation/cc_mlo_sim": evaluation_cc_mlo_cos_sim,
                        **{f"evaluation/cc-text-{k}": v for k, v in cc_text_evaluation_metrics.items()},
                        **{f"evaluation/mlo-text-{k}": v for k, v in mlo_text_evaluation_metrics.items()},
                        **{f"evaluation/cc-mlo-{k}": v for k, v in cc_mlo_evaluation_metrics.items()},
                        })
            # ========== Logging Step ==========


            # ========== Saving or Early Stopping ==========
            if avg_eval_loss < best_eval_loss: # Best evaluation performance found
                best_eval_loss = avg_eval_loss
                best_eval_loss_epoch = epoch+1

                # Save best model (overwrites the previous best checkpoint)
                save_path = os.path.join(Config.output_dir, "checkpoint")
                os.makedirs(save_path, exist_ok=True)
                # Save LoRA
                model.base_model.save_pretrained(save_path)
                # Save Projection Heads manually
                torch.save(model.visual_projection.state_dict(), os.path.join(save_path, "visual_proj.pt"))
                torch.save(model.text_projection.state_dict(), os.path.join(save_path, "text_proj.pt"))
                # The logit scale is a trained parameter too; without it the checkpoint cannot
                # reproduce the similarity scaling the model was trained with
                torch.save({"logit_scale": model.logit_scale.detach().cpu()}, os.path.join(save_path, "logit_scale.pt"))
                print(f"Saved checkpoint to {save_path}")

            elif epoch+1 - best_eval_loss_epoch >= 3: # Early stopping if 3 consecutive epoch did not improve best eval loss
                print(f"Early stopping at epoch {epoch+1} due to no improvement in evaluation loss")
                break
            # ========== Saving or Early Stopping ==========
    finally:
        # Always close the W&B run, even if training raises or is interrupted
        logger.finish()


In [None]:
# Result logged here
# https://wandb.ai/wjlingz-none/mammogram-contrastive-learning-project/

## Run the training loop

In [None]:
# Launch contrastive pre-training. Whenever the evaluation loss improves, the LoRA adapter and
# both projection heads are written to <Config.output_dir>/checkpoint; training stops early
# once 3 epochs pass without a new best evaluation loss.
train()

Starting Contrastive Pre-Training...


Epoch 1:   0%|          | 0/13 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Epoch 1: 100%|██████████| 13/13 [07:28<00:00, 34.51s/it, loss=3.36]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 2: 100%|██████████| 13/13 [04:10<00:00, 19.25s/it, loss=3.34]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 3: 100%|██████████| 13/13 [04:00<00:00, 18.50s/it, loss=3.26]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 4: 100%|██████████| 13/13 [04:06<00:00, 18.93s/it, loss=2.95]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 5: 100%|██████████| 13/13 [04:06<00:00, 18.93s/it, loss=2.66]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 6: 100%|██████████| 13/13 [04:04<00:00, 18.80s/it, loss=2.35]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 7: 100%|██████████| 13/13 [04:08<00:00, 19.11s/it, loss=2.14]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 8: 100%|██████████| 13/13 [04:07<00:00, 19.06s/it, loss=1.81]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 9: 100%|██████████| 13/13 [04:07<00:00, 19.06s/it, loss=1.57]


Saved checkpoint to ./mammo_clip_checkpoints/checkpoint


Epoch 10: 100%|██████████| 13/13 [04:00<00:00, 18.49s/it, loss=1.34]
Epoch 11: 100%|██████████| 13/13 [04:03<00:00, 18.75s/it, loss=1.16]
Epoch 12: 100%|██████████| 13/13 [04:02<00:00, 18.62s/it, loss=1.1]


Early stopping at epoch 12 due to no improvement in evaluation loss


0,1
evaluation/cc-mlo-R@1,▁▁▃▅▅▇██████
evaluation/cc-mlo-R@2,▁▁▃▆▆▆▇█████
evaluation/cc-mlo-R@3,▁▁▃▇▆▇▇▇████
evaluation/cc-text-R@1,▁▂▃▅▆▅█▆▆▇█▇
evaluation/cc-text-R@2,▁▁▂▅▅▄▇▅▇▇▇█
evaluation/cc-text-R@3,▁▁▂▅▅▃▆▅▇▇▆█
evaluation/cc_mlo_sim,███▇▆▅▄▃▂▂▁▁
evaluation/cc_text_sim,▁▁▁▅▇▇▇█▇▇▇▆
evaluation/epoch_loss,██▇▅▄▄▂▁▁▁▁▁
evaluation/mlo-text-R@1,▁▂▁▄▅▄▆▆▇▆▆█

0,1
evaluation/cc-mlo-R@1,75
evaluation/cc-mlo-R@2,84.09091
evaluation/cc-mlo-R@3,86.36364
evaluation/cc-text-R@1,25
evaluation/cc-text-R@2,38.63636
evaluation/cc-text-R@3,52.27273
evaluation/cc_mlo_sim,0.69318
evaluation/cc_text_sim,0.31889
evaluation/epoch_loss,1.50586
evaluation/mlo-text-R@1,31.81818


# Save Model

In [None]:
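# Save to Drive. The training loop overwrites this single `checkpoint` folder whenever the evaluation loss improves, so it always holds the best epoch (check the training log for which epoch that was)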
!cp -r /content/mammo_clip_checkpoints/checkpoint /content/drive/MyDrive/


In [None]:
# Disconnect the Colab runtime after a short grace period.
from google.colab import runtime
import time

# Grace period before disconnecting; presumably lets pending Drive writes flush — TODO confirm.
print("⏳ Waiting 30s for background tasks...")
time.sleep(30)

# Everything held in kernel memory is lost after this call; only files already copied to Drive remain.
print("👋 Disconnecting runtime.")
runtime.unassign()

⏳ Waiting 30s for background tasks...
👋 Disconnecting runtime.


## Reload the model from Drive, then upload it to the Hugging Face Hub

### Disconnect and reconnect the runtime first if there is not enough RAM
After reconnecting, remember to re-run the cells for Drive mounting, library install and imports, dataset loading, the Dataset definition, the model architecture definition, the loss function, and the validation function.  

To verify validation performance, see the Appendix.


In [None]:
# import gc

# del model
# torch.cuda.empty_cache()
# gc.collect()

11132

In [None]:
# Drive folder holding the saved checkpoint; the upload cell reads visual_proj.pt and text_proj.pt from it.
checkpoint_dir = "/content/drive/MyDrive/checkpoint"

In [None]:
# Merge the trained LoRA adapter into an unquantized MedGemma, then publish the merged model,
# its processor and the two projection heads to the Hugging Face Hub.
# Requires `checkpoint_dir` (defined in the previous cell) to point at the checkpoint folder on Drive.
from peft import PeftModel, PeftConfig
import os

# 1. Load Base Model (Dequantized) + Processor
base_model_id = "google/medgemma-4b-it"

model_kwargs = dict(
    attn_implementation="sdpa", # "sdpa", "eager"
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


dequantized_gemma = AutoModelForImageTextToText.from_pretrained(base_model_id, **model_kwargs)

# 2. Load the LoRA adapter from the Drive checkpoint and fold it into the base weights.
# `checkpoint_dir` is used (instead of a second hard-coded path) so the adapter and the
# projection heads uploaded below are guaranteed to come from the same folder.
base_model = PeftModel.from_pretrained(dequantized_gemma, checkpoint_dir)
merged_base_model = base_model.merge_and_unload()

processor = AutoProcessor.from_pretrained(base_model_id)
processor.tokenizer.padding_side = "right"

# Resolve the projection-head files up front and fail fast if one is missing, so the
# multi-GB model push below is not started for a checkpoint that cannot be published completely.
drive_visual_path = os.path.join(checkpoint_dir, "visual_proj.pt")
drive_text_path = os.path.join(checkpoint_dir, "text_proj.pt")
for required_file in (drive_visual_path, drive_text_path):
    if not os.path.isfile(required_file):
        raise FileNotFoundError(f"Missing projection head file: {required_file}")

## Upload to Hub
# Save to HuggingFace
from huggingface_hub import login
from google.colab import userdata
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download

# 3. Login (the token is read from Colab secrets; it needs write access to push)
login(token=userdata.get('HF_TOKEN'))
# DOUBLE-CHECK REPO_ID BEFORE RUNNING: this cell pushes to whatever repo is named below.
# "weijietling/medgemma-4b-it-contrastive-trained-130126" this model is the default contrastive training setup, early stopping at epoch 10, other model is for ablation study
# "weijietling/medgemma-4b-it-contrastive-trained-150126-mvs-ablation" this model use default clip contrastive setup without mvs loss (no img-img-loss)
repo_id = "weijietling/medgemma-4b-it-contrastive-trained-150126-mvs-ablation"

# 4. Push the merged model (called contrastive model in the future)
merged_base_model.push_to_hub(repo_id) # Takes ~10 minutes

# 5. Push the tokenizer/processor so you can use it later
processor.push_to_hub(repo_id)

# 6. Upload the projection heads to the Hub.
# Note the rename: visual_proj.pt -> visual_projection.pt and text_proj.pt -> text_projection.pt;
# the Hub-loading cell in the appendix downloads them under the new names.
api = HfApi()

# Upload Visual Projection
api.upload_file(
    path_or_fileobj=drive_visual_path,    # Source: Your Drive file
    path_in_repo="visual_projection.pt",  # Destination: Name in HF Hub
    repo_id=repo_id,
    repo_type="model"
)

# Upload Text Projection
api.upload_file(
    path_or_fileobj=drive_text_path,      # Source: Your Drive file
    path_in_repo="text_projection.pt",    # Destination: Name in HF Hub
    repo_id=repo_id,
    repo_type="model"
)

print(f"Successfully pushed all files to https://huggingface.co/{repo_id}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...0001-of-00002.safetensors:   1%|          | 41.9MB / 4.96GB            

  ...0002-of-00002.safetensors:   1%|1         | 41.9MB / 3.64GB            

No files have been modified since last commit. Skipping to prevent empty commit.


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...pc007l_zd/tokenizer.model: 100%|##########| 4.69MB / 4.69MB            

  ...mpc007l_zd/tokenizer.json: 100%|##########| 33.4MB / 33.4MB            

No files have been modified since last commit. Skipping to prevent empty commit.


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...checkpoint/visual_proj.pt: 100%|##########| 1.18MB / 1.18MB            

No files have been modified since last commit. Skipping to prevent empty commit.


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...e/checkpoint/text_proj.pt: 100%|##########| 2.62MB / 2.62MB            

No files have been modified since last commit. Skipping to prevent empty commit.


Successfully pushed all files to https://huggingface.co/weijietling/medgemma-4b-it-contrastive-trained-150126-mvs-ablation


### Comparison:
1. Medgemma-4b-it + few-shot training
2. Medgemma-4b-it + contrastive training + few-shot training
3. Medgemma-4b-it + image-report pretraining
4. Medgemma-4b-it + contrastive training + image-report pretraining


# Appendix

## Load model


### Quantized Gemma + LoRA Adapter

In [None]:
# Test for model weight loading
## Ensure the performance on validation set is same as during training time (validation set)
# This variant keeps the base model 4-bit quantized and attaches the LoRA adapter WITHOUT merging it.

from peft import PeftModel, PeftConfig

# 1. Setup Config
# NOTE: this redefines the notebook-global `Config`; any attribute not listed here
# (e.g. Config.output_dir used by the training loop) is gone until the original Config cell is re-run.
class Config:
    model_id = "google/medgemma-4b-it"
    validation_file = "validation_dataset.jsonl"
    num_workers = 2
    max_length = 128

# 2. Load Base Model (QLoRA)
model_kwargs = dict(
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 4-bit NF4 quantization is ACTIVE in this cell (the adapter is not merged here).
# Comment this block out only when the adapter is going to be merged into the base weights.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

gemma_model = AutoModelForImageTextToText.from_pretrained(Config.model_id, **model_kwargs)

# 3. Load Processor and fix padding
processor = AutoProcessor.from_pretrained(Config.model_id)
processor.tokenizer.padding_side = "right"

# 4. Load LoRA Adapters
# Point to your saved checkpoint folder
# NOTE(review): the save cell copies the folder to Drive as ".../MyDrive/checkpoint"; this path
# assumes it was renamed to "checkpoint-epoch-10" afterwards — confirm the folder exists.
checkpoint_path = "/content/drive/MyDrive/checkpoint-epoch-10"
base_model = PeftModel.from_pretrained(gemma_model, checkpoint_path)

# 5. Initialize Wrapper
# Note: We must ensure the wrapper is on the correct device/dtype
model = MammoContrastiveModel(base_model, hidden_size=base_model.config.text_config.hidden_size).cuda()

# 6. Load Projection Heads (saved next to the adapter by the training loop)
visual_proj_path = os.path.join(checkpoint_path, "visual_proj.pt")
text_proj_path = os.path.join(checkpoint_path, "text_proj.pt")

model.visual_projection.load_state_dict(torch.load(visual_proj_path))
model.text_projection.load_state_dict(torch.load(text_proj_path))

# 7. Manually restore the logit scale (optional, but needed for the loss to match training).
# It is not restored by anything above; look up 'training/logit_scale' in the WandB run for the
# epoch of this checkpoint. NOTE: the training loop logs model.logit_scale.exp(), so take the log
# of the logged value before assigning it to the raw parameter, e.g. for a logged value s:
# model.logit_scale.data = torch.tensor(s).log().float().cuda()

print("Model fully restored!")

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Model fully restored!


### Dequantized Gemma Merged with LoRA Adapter
(Merging the LoRA adapter into the quantized Gemma gives noticeably worse performance, so the base model is loaded unquantized here.)

In [None]:
# Test for model weight loading
## Ensure the performance on validation set is same as during training time (validation set)
# This variant loads the base model unquantized (bf16) and merges the LoRA adapter into its weights.

from peft import PeftModel, PeftConfig

# 1. Setup Config
# NOTE: this redefines the notebook-global `Config`; any attribute not listed here
# (e.g. Config.output_dir used by the training loop) is gone until the original Config cell is re-run.
class Config:
    model_id = "google/medgemma-4b-it"
    validation_file = "validation_dataset.jsonl"
    num_workers = 2
    max_length = 128

# 2. Load Base Model in bf16 (no quantization in this cell)
model_kwargs = dict(
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Skip quantization if we are going to merge the LoRA adapter
# model_kwargs["quantization_config"] = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
#     bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
# )

gemma_model = AutoModelForImageTextToText.from_pretrained(Config.model_id, **model_kwargs)

# 3. Load Processor and fix padding
processor = AutoProcessor.from_pretrained(Config.model_id)
processor.tokenizer.padding_side = "right"

# 4. Load LoRA Adapters
# Point to your saved checkpoint folder
# NOTE(review): the save cell copies the folder to Drive as ".../MyDrive/checkpoint"; this path
# assumes it was renamed to "checkpoint-epoch-10" afterwards — confirm the folder exists.
checkpoint_path = "/content/drive/MyDrive/checkpoint-epoch-10"
base_model = PeftModel.from_pretrained(gemma_model, checkpoint_path)

# 5. Merge the adapter into the base weights, then initialize the wrapper around the merged model
# Note: We must ensure the wrapper is on the correct device/dtype
merged_base_model = base_model.merge_and_unload()
model = MammoContrastiveModel(merged_base_model, hidden_size=base_model.config.text_config.hidden_size).cuda()

# 6. Load Projection Heads (saved next to the adapter by the training loop)
visual_proj_path = os.path.join(checkpoint_path, "visual_proj.pt")
text_proj_path = os.path.join(checkpoint_path, "text_proj.pt")

model.visual_projection.load_state_dict(torch.load(visual_proj_path))
model.text_projection.load_state_dict(torch.load(text_proj_path))

# 7. Manually restore the logit scale (optional, but needed for the loss to match training).
# It is not restored by anything above; look up 'training/logit_scale' in the WandB run for the
# epoch of this checkpoint. NOTE: the training loop logs model.logit_scale.exp(), so take the log
# of the logged value before assigning it to the raw parameter, e.g. for a logged value s:
# model.logit_scale.data = torch.tensor(s).log().float().cuda()

print("Model fully restored!")

### Quantized Contrastive Model (HuggingFace Hub)
(Contrastive Model = Dequantized Gemma Merged with LoRA)

In [None]:
# Test for model weight loading THIS TIME LOAD THE MERGED MODEL STORED IN HUGGINGFACE HUB
## Ensure the performance on validation set is same as during training time (validation set)
# The Hub repo already contains the LoRA-merged weights, so no adapter is loaded in this cell;
# the merged model is quantized to 4-bit at load time.

from peft import PeftModel, PeftConfig
from huggingface_hub import login
from google.colab import userdata
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download

# The token is read from the Colab secret named HF_TOKEN (set it in the Colab "Secrets" panel)
login(token=userdata.get('HF_TOKEN'))

# 1. Setup Config
# NOTE: this redefines the notebook-global `Config`; any attribute not listed here
# (e.g. Config.output_dir used by the training loop) is gone until the original Config cell is re-run.
class Config:
    model_id = "weijietling/medgemma-4b-it-contrastive-trained-150126-mvs-ablation" # "weijietling/medgemma-4b-it-contrastive-trained-130126"
    validation_file = "validation_dataset.jsonl"
    num_workers = 2
    max_length = 128

# 2. Load the merged contrastive model from the Hub
model_kwargs = dict(
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 4-bit NF4 quantization is ACTIVE in this cell; it is applied to the already-merged weights.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

contrastive_model = AutoModelForImageTextToText.from_pretrained(Config.model_id, **model_kwargs)

# 3. Load Processor and fix padding
processor = AutoProcessor.from_pretrained(Config.model_id)
processor.tokenizer.padding_side = "right"

# 4. (Not needed here) LoRA adapters are already merged into the Hub weights.
# Kept for reference:
# checkpoint_path = "/content/drive/MyDrive/checkpoint-epoch-10"
# base_model = PeftModel.from_pretrained(gemma_model, checkpoint_path)

# 5. Initialize Wrapper
# Note: We must ensure the wrapper is on the correct device/dtype
# model = MammoContrastiveModel(base_model, hidden_size=base_model.config.text_config.hidden_size).cuda()

model = MammoContrastiveModel(contrastive_model, hidden_size=contrastive_model.config.text_config.hidden_size).cuda()

# 6. Load Projection Heads (downloaded from the same Hub repo; file names match the upload cell)
visual_proj_path = hf_hub_download(repo_id=Config.model_id, filename="visual_projection.pt", repo_type="model")
text_proj_path = hf_hub_download(repo_id=Config.model_id, filename="text_projection.pt", repo_type="model")

model.visual_projection.load_state_dict(torch.load(visual_proj_path))
model.text_projection.load_state_dict(torch.load(text_proj_path))

# 7. Manually restore the logit scale (optional, but needed for the loss to match training).
# It is not restored by anything above; look up 'training/logit_scale' in the WandB run for the
# epoch of this checkpoint. NOTE: the training loop logs model.logit_scale.exp(), so take the log
# of the logged value before assigning it to the raw parameter, e.g. for a logged value s:
# model.logit_scale.data = torch.tensor(s).log().float().cuda()

print("Model fully restored!")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

visual_projection.pt:   0%|          | 0.00/1.18M [00:00<?, ?B/s]

text_projection.pt:   0%|          | 0.00/2.62M [00:00<?, ?B/s]

Model fully restored!


### Evaluation on validation

In [None]:
# Evaluation on the whole validation set
# Mirrors the evaluation section of the training loop so the numbers are comparable with the WandB logs.
# Requires `model`, `processor` and `Config` from one of the "Load model" cells above, plus the
# MammoDataset / contrastive_loss / compute_recall_k definitions from earlier in the notebook.
# NOTE(review): split="test" is passed for the validation file — presumably it selects the
# non-augmented preprocessing path; confirm against MammoDataset.
validation_dataset = MammoDataset(Config.validation_file, processor, split="test")
validation_dataloader = DataLoader(validation_dataset, batch_size=4, shuffle=False, num_workers=Config.num_workers)

model.eval()    # Switch to evaluation mode

# Running totals over all batches
evaluation_loss = 0
all_evaluation_z_cc = []
all_evaluation_z_mlo = []
all_evaluation_z_text = []
evaluation_cc_text_total_cos_sim = 0
evaluation_mlo_text_total_cos_sim = 0
evaluation_cc_mlo_total_cos_sim = 0
validation_data_size = len(validation_dataset)

with torch.no_grad():
    for batch in validation_dataloader:
        # Move inputs to GPU
        pv_cc = batch["pixel_values_cc"].cuda().to(torch.bfloat16)
        pv_mlo = batch["pixel_values_mlo"].cuda().to(torch.bfloat16)
        ids = batch["input_ids"].cuda()
        mask = batch["attention_mask"].cuda()

        # Forward
        z_cc, z_mlo, z_text = model(pv_cc, pv_mlo, ids, mask)

        # Calculate Losses
        # The reported loss always averages all three pairwise terms (including image-image),
        # regardless of which terms the loaded checkpoint was trained with.
        loss_cc_text, cc_text_summed_cos_sim = contrastive_loss(z_cc, z_text, model.logit_scale)
        loss_mlo_text, mlo_text_summed_cos_sim = contrastive_loss(z_mlo, z_text, model.logit_scale)
        loss_img_img, cc_mlo_summed_cos_sim = contrastive_loss(z_cc, z_mlo, model.logit_scale)
        total_loss = (loss_cc_text + loss_mlo_text + loss_img_img) / 3

        # Recording the output to avoid recomputation for evaluation purpose
        all_evaluation_z_cc.append(z_cc.float().cpu())
        all_evaluation_z_mlo.append(z_mlo.float().cpu())
        all_evaluation_z_text.append(z_text.float().cpu())
        evaluation_cc_text_total_cos_sim += cc_text_summed_cos_sim.item()
        evaluation_mlo_text_total_cos_sim += mlo_text_summed_cos_sim.item()
        evaluation_cc_mlo_total_cos_sim += cc_mlo_summed_cos_sim.item()

        evaluation_loss += total_loss.item()

# Compute performance metrics for logging
# Similarity totals are divided by the number of samples (assumes contrastive_loss returns a
# per-batch SUM of similarities, as the variable names suggest — TODO confirm).
evaluation_cc_text_cos_sim = evaluation_cc_text_total_cos_sim / validation_data_size
evaluation_mlo_text_cos_sim = evaluation_mlo_text_total_cos_sim / validation_data_size
evaluation_cc_mlo_cos_sim = evaluation_cc_mlo_total_cos_sim / validation_data_size

# Recall@k is computed over the whole validation set at once, not per batch
z_cc_evaluation = torch.cat(all_evaluation_z_cc, dim=0)
z_mlo_evaluation = torch.cat(all_evaluation_z_mlo, dim=0)
z_text_evaluation = torch.cat(all_evaluation_z_text, dim=0)

cc_text_evaluation_metrics = compute_recall_k(z_cc_evaluation, z_text_evaluation)
mlo_text_evaluation_metrics = compute_recall_k(z_mlo_evaluation, z_text_evaluation)
cc_mlo_evaluation_metrics = compute_recall_k(z_cc_evaluation, z_mlo_evaluation)

# Loss is averaged over batches (each batch contributes its own mean loss)
avg_eval_loss = evaluation_loss / len(validation_dataloader)

# Output order: cc-text sim, mlo-text sim, cc-mlo sim, cc-text recall, mlo-text recall,
# cc-mlo recall, average loss. The result listings below the cell follow the same order.
print(evaluation_cc_text_cos_sim)
print(evaluation_mlo_text_cos_sim)
print(evaluation_cc_mlo_cos_sim)
print(cc_text_evaluation_metrics)
print(mlo_text_evaluation_metrics)
print(cc_mlo_evaluation_metrics)
print(avg_eval_loss)

0.3625710227272727
0.3938210227272727
0.7333096590909091
{'R@1': 25.0, 'R@2': 34.090909361839294, 'R@3': 43.18181872367859}
{'R@1': 27.272728085517883, 'R@2': 34.090909361839294, 'R@3': 43.18181872367859}
{'R@1': 77.27272510528564, 'R@2': 84.09090638160706, 'R@3': 86.36363744735718}
0.6269891912286932


Quantized Contrastive Model (Ablated -- no img_img_loss)  
0.3625710227272727  
0.3938210227272727  
0.7333096590909091  
{'R@1': 25.0, 'R@2': 34.090909361839294, 'R@3': 43.18181872367859}  
{'R@1': 27.272728085517883, 'R@2': 34.090909361839294, 'R@3': 43.18181872367859}  
{'R@1': 77.27272510528564, 'R@2': 84.09090638160706, 'R@3': 86.36363744735718}  
0.6269891912286932  

Quantized Gemma Merged with LoRA Adapter:  (Dequantized first, then merge, then quantized again)  
0.24884588068181818  
0.26171875  
0.7432528409090909  
{'R@1': 18.18181872367859, 'R@2': 20.454545319080353, 'R@3': 29.545453190803528}  
{'R@1': 22.727273404598236, 'R@2': 31.81818127632141, 'R@3': 38.63636255264282}  
{'R@1': 88.63636255264282, 'R@2': 95.45454382896423, 'R@3': 95.45454382896423}  
0.6533203125  

Dequantized Gemma Merged with LoRA Adapter:   (Dequantized first, then merge)  
0.21075994318181818  
0.21484375  
0.7681107954545454  
{'R@1': 18.18181872367859, 'R@2': 29.545453190803528, 'R@3': 34.090909361839294}  
{'R@1': 18.18181872367859, 'R@2': 36.36363744735718, 'R@3': 40.909090638160706}  
{'R@1': 88.63636255264282, 'R@2': 93.18181872367859, 'R@3': 95.45454382896423}  
0.6944691051136364  

Quantized Gemma + LoRA Adapter:  (Dequantized and load LoRA without merging)  
0.26118607954545453  
0.2731711647727273  
0.7386363636363636  
{'R@1': 18.18181872367859, 'R@2': 27.272728085517883, 'R@3': 31.81818127632141}  
{'R@1': 25.0, 'R@2': 34.090909361839294, 'R@3': 43.18181872367859}  
{'R@1': 84.09090638160706, 'R@2': 93.18181872367859, 'R@3': 95.45454382896423}  
0.6402254971590909  

## Others

In [None]:
# Uses 24GB of ram using config batch size 4, grad accumulation 1. 36gb after adding requires_grad_(True)
# Batch size, grad acc step, RAM, training time per epoch, initial loss, loss at the end of first epoch, loss at the beginning of second epoch
# 4, 1, 35.9gb, 10min40sec, 1.4, 1.18, 0.88 (forgot to activate image augmentation pipeline)
# 4, 8, 35.9gb, 10min42sec, 1.4, 1.37, 1.29 (forgot to activate image augmentation pipeline) -> probably dont use gradient accumulation, no benefits using it
# 4, 1, 35.9gb, 10min53sec, 1.4 , 1.2, 0.9 (with augmentation from here onward) -> can use higher batch size probably
# 8, 1, 55.9gb(63.9), 10min07sec, 2.08, 1.81(1.17)(0.816)(0.584), 1.14(0.82)(0.64)(0.535)
# (previously all using "eager" attention, now trying "sdpa")
# (accidentally used dataloader instead of validation_dataloader) 8, 1, 20.4gb (28.8 if add validation loop), 6min50sec (3min11sec for validation loop), 1.81, 1.3
# (correctly using validation_dataloader) 8, 1, 20.4gb (20.4gb if add validation loop), 6min50sec (33sec for validation loop, 10sec in subsequent validation), 1.81(1.2)(0.84)(0.59), 1.3 -> sdpa have similar performance with eager. validation is working correctly
# 32, 1, 44.0gb, 8min4sec (1min02sec validation loop), 3.34

In [None]:
# To find the best ratio range for data augmentation so that the original mammography dimension (tall rectangular) does not deviate too much after augmentation
import os
from PIL import Image
import numpy as np

def analyze_image_ratios(folder_path):
    """Collect the width/height aspect ratio of every image file directly inside a folder.

    Files are selected by extension (.jpg, .jpeg, .png, .bmp, .gif, case-insensitive);
    sub-folders are not searched. A file that cannot be processed is reported on stdout
    and skipped, so one corrupt image does not abort the scan.

    Args:
        folder_path: Directory to scan.

    Returns:
        A tuple ``(avg_ratio, std_ratio, ratios)``: the mean and the (population) standard
        deviation of W/H, and a 1-D NumPy array with one ratio per readable image. When no
        image could be read, both statistics are NaN and the array is empty.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``folder_path`` cannot be listed
            (e.g. ``FileNotFoundError`` for a missing folder).
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    ratios = []

    for filename in os.listdir(folder_path):
        if not filename.lower().endswith(image_suffixes):
            continue
        img_path = os.path.join(folder_path, filename)
        try:
            # Only the header is needed for .size; the context manager closes the file
            # handle right away instead of leaving one open per image.
            with Image.open(img_path) as img:
                width, height = img.size
            ratios.append(width / height)
        except Exception as e:
            # Best effort: report and continue with the remaining files.
            print(f"Error processing {filename}: {e}")

    ratios = np.array(ratios, dtype=float)
    if ratios.size == 0:
        # np.mean/np.std on an empty array also give NaN but emit RuntimeWarnings;
        # return the same values without the noise.
        return np.nan, np.nan, ratios

    return np.mean(ratios), np.std(ratios), ratios

# Usage
# The image folder lives on Drive and is not always present (e.g. Drive not mounted or the
# dataset moved), so skip the analysis with a message instead of failing the whole notebook run.
folder_path = "/content/drive/MyDrive/CDD-CESM-curated-dataset/images/"
if os.path.isdir(folder_path):
    avg, std, all_ratios = analyze_image_ratios(folder_path)

    print(f"Average ratio (W/H): {avg:.4f}")
    print(f"Standard deviation: {std:.4f}")
    print(f"Total images: {len(all_ratios)}")
else:
    print(f"Image folder not found, skipping aspect-ratio analysis: {folder_path}")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/CDD-CESM-curated-dataset/images/'

In [None]:
# Path stored under `cc_path` in the first training record; the augmentation demo below uses the same image.
example_image_path = data['train'][0]['cc_path']

In [None]:
# Visualise the image augmentation pipeline step by step (crop -> jitter -> blur -> resize)
# on one training image, as a left-to-right flowchart.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import ConnectionPatch
# Imported under the same alias as the top-of-notebook import cell: a plain
# `from PIL import Image` would rebind the global `Image` that was imported from `datasets`.
from PIL import Image as PILImage
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import numpy as np

# 0. SETUP: Define your image path
example_image_path = data['train'][0]['cc_path']
# Load image (or create dummy if file missing for testing)
try:
    orig_img = PILImage.open(example_image_path).convert('RGB')
except (FileNotFoundError, OSError):
    print(f"File '{example_image_path}' not found. Generating a random 2000x2000 image for demonstration.")
    orig_img = PILImage.fromarray(np.uint8(np.random.rand(2000, 2000, 3) * 255))

# ---------------------------------------------------------
# PIPELINE EXECUTION (Step-by-Step)
# ---------------------------------------------------------

# 1. Calculate Crop Params (No resize yet)
# scale=(0.7, 0.7) pins the crop to 70% of the original image AREA (no randomness in size);
# ratio=(0.45, 0.65) samples the crop's width/height ratio, keeping it a tall rectangle.
roi_selector = T.RandomResizedCrop(size=(896, 896), scale=(0.7, 0.7), ratio=(0.45, 0.65))
# get_params returns (top, left, height, width) of the sampled crop
i, j, h, w = roi_selector.get_params(orig_img, scale=(0.7, 0.7), ratio=(0.45, 0.65))

# 2. Perform Crop
step1_crop = TF.crop(orig_img, i, j, h, w)

# 3. Perform Jitter
jitter_transform = T.ColorJitter(brightness=0.4, contrast=0.4)
step2_jitter = jitter_transform(step1_crop)

# 4. Perform Blur
blur_transform = T.GaussianBlur(kernel_size=39, sigma=2.0)
step3_blur = blur_transform(step2_jitter)

# 5. Perform Resize (Final 896x896)
step4_resize = TF.resize(step3_blur, size=(896, 896))

# ---------------------------------------------------------
# PLOTTING WITH ARROWS & RESIZING
# ---------------------------------------------------------

# We use width_ratios to make the last column (Resize) visually smaller (0.6 width)
# vs the others (1.0 width).
fig, axs = plt.subplots(1, 5, figsize=(24, 6),
                        gridspec_kw={'width_ratios': [1, 1, 1, 1, 0.6]})

# --- Plot 1: Original with Box ---
axs[0].imshow(orig_img)
axs[0].set_title(f"1. Original Input\n{orig_img.size}", fontsize=22, fontweight='bold')
# Draw the red box indicating what will be cropped (Rectangle takes (x, y) = (left, top))
rect = patches.Rectangle((j, i), w, h, linewidth=4, edgecolor='#ff0055', facecolor='none')
axs[0].add_patch(rect)
axs[0].axis('off')

# --- Plot 2: Cropped Patch ---
axs[1].imshow(step1_crop)
axs[1].set_title(f"2. Cropped Patch\n{step1_crop.size}", fontsize=22, fontweight='bold')
axs[1].axis('off')

# --- Plot 3: Jittered ---
axs[2].imshow(step2_jitter)
axs[2].set_title("3. Color Jitter", fontsize=22, fontweight='bold')
axs[2].axis('off')

# --- Plot 4: Blurred ---
axs[3].imshow(step3_blur)
axs[3].set_title("4. Gaussian Blur", fontsize=22, fontweight='bold')
axs[3].axis('off')

# --- Plot 5: Final Resize (Will appear smaller due to width_ratios) ---
axs[4].imshow(step4_resize)
axs[4].set_title(f"5. Resized Final\n{step4_resize.size}", fontsize=22, fontweight='bold')
axs[4].axis('off')

# ---------------------------------------------------------
# DRAWING ARROWS
# ---------------------------------------------------------
def draw_arrow(ax_src, ax_dst):
    """Draw a black arrow from the right edge of ``ax_src`` to the left edge of ``ax_dst``."""
    # Create an arrow extending from the right of source to left of dest
    con = ConnectionPatch(xyA=(1.0, 0.5), xyB=(0.0, 0.5),
                          coordsA="axes fraction", coordsB="axes fraction",
                          axesA=ax_src, axesB=ax_dst,
                          arrowstyle="simple,head_width=2,head_length=2", linewidth=2, color="black",
                          shrinkA=10, shrinkB=10) # shrink adds padding so arrow doesn't touch image
    ax_src.add_artist(con)

# Draw arrows between 0->1, 1->2, 2->3, 3->4
for k in range(4):
    draw_arrow(axs[k], axs[k+1])

plt.tight_layout()
plt.show()

# Optional: Save to file. Use the figure handle: once the figure has been shown, pyplot no
# longer tracks it and plt.savefig() would write a new, empty figure.
# fig.savefig('pipeline_flowchart.png', dpi=300, bbox_inches='tight')

In [None]:
# Save through the `fig` handle created by the plotting cell. A bare plt.savefig() in a new cell
# targets pyplot's *current* figure, and the inline backend has already closed the plotted one,
# so it would write a blank image.
fig.savefig('pipeline_flowchart.png', dpi=900, bbox_inches='tight')

In [None]:
# Display the second training record to inspect its fields
data['train'][1]