## On Zaratan, you may want to set:

USE_SAMPLE_TEXT = False     # full text dataset
MAX_TRAIN_MM = None         # use all multimodal train samples
MAX_VAL_MM   = None         # use all multimodal val samples


In [None]:
!pip install -q transformers datasets accelerate


In [None]:
# ==== Global config, mount Drive, locate processed data + multimodal embeddings ====
# Google Drive can only be mounted from inside Colab. On other machines (e.g. a
# cluster such as Zaratan) the `google.colab` package is not installed, so skip
# the mount instead of crashing; ROOT below must then point at a local directory.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Drive mount.")
else:
    drive.mount("/content/drive")

import os
import glob
import math
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ---------------- High-level toggles ----------------
# Selects which text CSV is read from BASE_DIR later on:
# sample_data.csv when True, processed_data.csv when False.
USE_SAMPLE_TEXT = False     # True for Colab debugging, False for full decoder training


# Optional caps on the multimodal train/val sample lists. An integer N keeps
# only the first N samples (handy for quick Colab runs); None keeps everything.
MAX_TRAIN_MM = None          # None = use ALL train samples (e.g. on Zaratan)
MAX_VAL_MM   = None          # None = use ALL val samples (e.g. on Zaratan)

# Text-only decoder hyperparams (fed to the text-only TrainingArguments and
# to the tokenizer's max_length further down)
TEXT_NUM_EPOCHS  = 5
TEXT_LR          = 5e-5
TEXT_TRAIN_BS    = 8       # per-device batch size
TEXT_EVAL_BS     = 8       # per-device batch size
TEXT_GRAD_ACCUM  = 4       # gradient accumulation steps
TEXT_MAX_LENGTH  = 256     # token budget per report (longer reports are truncated)

# Multimodal decoder hyperparams (fed to the multimodal TrainingArguments and
# to MultimodalReportDataset)
MM_NUM_EPOCHS    = 10
MM_LR            = 2e-5
MM_TRAIN_BS      = 4       # per-device batch size
MM_EVAL_BS       = 4       # per-device batch size
MM_GRAD_ACCUM    = 4       # gradient accumulation steps
MM_MAX_LENGTH    = 256     # token budget per report (longer reports are truncated)
COND_DIM         = 1536    # 768 (ViT pooled) + 768 (ClinicalBERT)

# ---------------- Locate processed_data.csv (text) ----------------
ROOT = "/content/drive"
print("\nSearching under:", ROOT)

# glob returns matches in filesystem order, which is not stable across
# machines; sort so that "the first match" is always the same file.
proc_matches = sorted(
    glob.glob(os.path.join(ROOT, "**", "processed_data.csv"), recursive=True)
)
print("\nFound processed_data.csv matches:")
for m in proc_matches:
    print("  ", m)

if not proc_matches:
    raise FileNotFoundError(
        f"Could not find processed_data.csv anywhere under {ROOT}. "
        "Make sure the shared 'processed data' folder is visible in Drive."
    )

if len(proc_matches) > 1:
    # Several copies exist; make the choice visible instead of silent.
    print("\nWARNING: multiple processed_data.csv files found; using the first one listed.")

# Everything text-related (CSVs, saved decoders) lives next to the chosen CSV.
BASE_DIR = os.path.dirname(proc_matches[0])

print("\nInferred BASE_DIR:", BASE_DIR)
print("Contents of BASE_DIR:", os.listdir(BASE_DIR))

# ---------------- Locate multimodal_dataset_full.pt ----------------
# glob returns matches in filesystem order, which is not stable across
# machines; sort so that "the first match" is always the same file.
mm_full_matches = sorted(
    glob.glob(os.path.join(ROOT, "**", "multimodal_dataset_full.pt"), recursive=True)
)
print("\nFound multimodal_dataset_full.pt matches:")
for m in mm_full_matches:
    print("  ", m)

if not mm_full_matches:
    raise FileNotFoundError(
        f"Could not find multimodal_dataset_full.pt under {ROOT}. "
        "Make sure the 'Dataset Embeddings' folder is visible / added to MyDrive."
    )

if len(mm_full_matches) > 1:
    # Several copies exist; make the choice visible instead of silent.
    print("\nWARNING: multiple multimodal_dataset_full.pt files found; using the first one listed.")

MM_FULL_PATH = mm_full_matches[0]
EMBEDDINGS_DIR = os.path.dirname(MM_FULL_PATH)

print("\nUsing EMBEDDINGS_DIR:", EMBEDDINGS_DIR)
print("Using MM_FULL_PATH:", MM_FULL_PATH)

# Quick sanity check
#tmp = torch.load(MM_FULL_PATH, map_location="cpu")
#print("\nMultimodal FULL dataset splits:", tmp.keys())
#print("Metadata:", tmp.get("metadata", {}))
#del tmp


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda

Searching under: /content/drive

Found processed_data.csv matches:
   /content/drive/MyDrive/processed data/data/data/processed_data.csv

Inferred BASE_DIR: /content/drive/MyDrive/processed data/data/data
Contents of BASE_DIR: ['sample_data.csv', 'images', 'processed_data.csv', 'sample_decoder_tokenizer', 'sample_decoder_pretrained', 'sample_multimodal_decoder']

Found multimodal_dataset_full.pt matches:
   /content/drive/MyDrive/processed data/encoder for full data/multimodal_dataset_full.pt

Using EMBEDDINGS_DIR: /content/drive/MyDrive/processed data/encoder for full data
Using MM_FULL_PATH: /content/drive/MyDrive/processed data/encoder for full data/multimodal_dataset_full.pt


In [None]:
# ==== Load IU X-Ray text data (sample_data.csv or processed_data.csv) ====
import pandas as pd

# Both CSVs sit in BASE_DIR; the toggle only decides which file name is read.
csv_name = "sample_data.csv" if USE_SAMPLE_TEXT else "processed_data.csv"
csv_path = os.path.join(BASE_DIR, csv_name)

print("Using CSV:", csv_path)
df = pd.read_csv(csv_path)

# Quick look at the schema and the first few rows.
print("Columns:", df.columns.tolist())
print(df.head())

# Build a combined text field: <FINDINGS> ... <IMPRESSION> ...
def build_report(row):
    """Combine a row's findings and impression into one tagged training string.

    Args:
        row: mapping-like object supporting ``.get`` (a DataFrame row or a
            dict) with optional "findings_final" / "impression_final" entries.

    Returns:
        "<FINDINGS> ... <IMPRESSION> ..." with each section present only when
        its text is non-empty; "" when both sections are missing or blank.
    """
    def _clean(value):
        # pandas reads empty CSV cells as NaN; str(NaN) would be the literal
        # "nan" (and str(None) "None"), which would pass as real report text.
        if value is None or pd.isna(value):
            return ""
        return str(value).strip()

    sections = (
        ("<FINDINGS> ", _clean(row.get("findings_final", ""))),
        ("<IMPRESSION> ", _clean(row.get("impression_final", ""))),
    )
    return " ".join(tag + text for tag, text in sections if text)

df["report_text"] = df.apply(build_report, axis=1)

# Keep only rows whose combined report still contains text after stripping.
has_text = df["report_text"].str.strip().astype(bool)
df = df.loc[has_text].reset_index(drop=True)
print("\nAfter building report_text, total rows:", len(df))
print(df[["uid", "filename", "report_text"]].head())


Using CSV: /content/drive/MyDrive/processed data/data/data/processed_data.csv
Columns: ['uid', 'filename', 'projection', 'findings_final', 'impression_final', 'full_report', 'findings_len', 'impression_len', 'full_report_len', 'MeSH', 'Problems', 'split']
   uid                filename projection  \
0    1  1_IM-0001-4001.dcm.png    Frontal   
1    1  1_IM-0001-3001.dcm.png    Lateral   
2    2  2_IM-0652-1001.dcm.png    Frontal   
3    2  2_IM-0652-2001.dcm.png    Lateral   
4    3  3_IM-1384-1001.dcm.png    Frontal   

                                      findings_final  \
0  The cardiac silhouette and mediastinum size ar...   
1  The cardiac silhouette and mediastinum size ar...   
2  Borderline cardiomegaly. Midline sternotomy . ...   
3  Borderline cardiomegaly. Midline sternotomy . ...   
4                                             normal   

                                    impression_final  \
0                                   Normal chest x-.   
1                       

In [None]:
# ==== Train/val/test split using 'split' column and wrap into HF Datasets ====
from datasets import Dataset, DatasetDict

# The CSV is expected to carry its own train/val/test assignment.
if "split" not in df.columns:
    raise ValueError("Expected a 'split' column in the CSV (train/val/test).")

# One frame per split, each re-indexed from zero.
train_df, val_df, test_df = (
    df[df["split"] == split_name].reset_index(drop=True)
    for split_name in ("train", "val", "test")
)

print("Full sizes -> train/val/test:", len(train_df), len(val_df), len(test_df))

# No subsampling is applied, whether the sample CSV or the full CSV is in use.
# For a quick run on the full CSV, uncomment:
# train_df = train_df.sample(n=min(len(train_df), 2000), random_state=42).reset_index(drop=True)
# val_df   = val_df.sample(n=min(len(val_df),  250), random_state=42).reset_index(drop=True)
# test_df  = test_df.sample(n=min(len(test_df), 250), random_state=42).reset_index(drop=True)

print("Using sizes -> train/val/test:", len(train_df), len(val_df), len(test_df))

# Wrap the report_text column of each split in a Hugging Face Dataset.
ds_train, ds_val, ds_test = (
    Dataset.from_pandas(frame[["report_text"]])
    for frame in (train_df, val_df, test_df)
)

# Note the key is "validation" here even though the CSV split is called "val".
text_ds = DatasetDict(
    {"train": ds_train, "validation": ds_val, "test": ds_test}
)

print(text_ds)


Full sizes -> train/val/test: 5223 1122 1121
Using sizes -> train/val/test: 5223 1122 1121
DatasetDict({
    train: Dataset({
        features: ['report_text'],
        num_rows: 5223
    })
    validation: Dataset({
        features: ['report_text'],
        num_rows: 1122
    })
    test: Dataset({
        features: ['report_text'],
        num_rows: 1121
    })
})


In [None]:
# ==== Initialize GPT-2 tokenizer with medical special tokens ====
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Register the sequence markers and the two section tags as special tokens.
special_tokens = dict(
    bos_token="<BOS>",
    eos_token="<EOS>",
    pad_token="<PAD>",
    additional_special_tokens=["<FINDINGS>", "<IMPRESSION>"],
)
tokenizer.add_special_tokens(special_tokens)
print("Special tokens map:", tokenizer.special_tokens_map)

# Frequent radiology terms added to the vocabulary as regular tokens.
MED_TOKENS = [
    "cardiomegaly",
    "atelectasis",
    "consolidation",
    "effusion",
    "pneumothorax",
    "edema",
    "collapse",
    "opacity",
    "opacities",
    "hyperinflation",
    "fibrosis",
    "infiltrate",
    "infiltrates",
    "pleural",
    "interstitial",
]

# add_tokens reports how many entries were actually new to the vocabulary.
added = tokenizer.add_tokens(MED_TOKENS)
print("Added medical tokens:", added)
print("New vocab size:", len(tokenizer))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Special tokens map: {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'unk_token': '<|endoftext|>', 'pad_token': '<PAD>', 'additional_special_tokens': ['<FINDINGS>', '<IMPRESSION>']}
Added medical tokens: 15
New vocab size: 50276


In [None]:
# ==== Load GPT-2 model and resize embeddings for new tokens ====
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Grow the embedding matrix so it covers the tokens added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

# Mirror the tokenizer's BOS/EOS/PAD ids onto the model config.
for _id_field in ("bos_token_id", "eos_token_id", "pad_token_id"):
    setattr(model.config, _id_field, getattr(tokenizer, _id_field))

model = model.to(device)
print("bos_token_id:", model.config.bos_token_id)
print("eos_token_id:", model.config.eos_token_id)
print("pad_token_id:", model.config.pad_token_id)
print("Hidden size:", model.config.n_embd)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


bos_token_id: 50257
eos_token_id: 50258
pad_token_id: 50259
Hidden size: 768


In [None]:
# ==== Tokenize text dataset for Causal LM ====
def tokenize_function(batch):
    """Tokenize a batch of reports as "<BOS> text <EOS>" for causal-LM training.

    The tokenizer call does not append an end-of-sequence token by itself, so
    the EOS id is added here; without it the model never sees an EOS target
    and cannot learn where a report ends.

    Args:
        batch: mapping with a "report_text" list of strings.

    Returns:
        dict with unpadded "input_ids" and "attention_mask" lists, one entry
        per report, each at most TEXT_MAX_LENGTH tokens long.
    """
    eos_id = tokenizer.eos_token_id
    # Reserve one slot so truncation can never cut off the EOS marker.
    budget = TEXT_MAX_LENGTH - 1 if eos_id is not None else TEXT_MAX_LENGTH

    encoded = tokenizer(
        [tokenizer.bos_token + " " + t for t in batch["report_text"]],
        truncation=True,
        max_length=budget,
        return_attention_mask=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    if eos_id is not None:
        input_ids = [ids + [eos_id] for ids in input_ids]
        attention_mask = [mask + [1] for mask in attention_mask]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Tokenize every split in batches and drop the raw text column afterwards.
tokenized_text_ds = text_ds.map(
    tokenize_function,
    batched=True,
    remove_columns=["report_text"],
)

# For causal LM, labels = input_ids
def add_labels(batch):
    """Attach causal-LM targets to a batch.

    Stores a shallow copy of ``batch["input_ids"]`` under "labels" and returns
    the same (mutated) batch object.
    """
    token_ids = batch["input_ids"]
    batch["labels"] = token_ids.copy()
    return batch

# Second pass over every split: add the "labels" column.
tokenized_text_ds = tokenized_text_ds.map(
    add_labels,
    batched=True,
)

print(tokenized_text_ds)


Map:   0%|          | 0/5223 [00:00<?, ? examples/s]

Map:   0%|          | 0/1122 [00:00<?, ? examples/s]

Map:   0%|          | 0/1121 [00:00<?, ? examples/s]

Map:   0%|          | 0/5223 [00:00<?, ? examples/s]

Map:   0%|          | 0/1122 [00:00<?, ? examples/s]

Map:   0%|          | 0/1121 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 5223
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1122
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1121
    })
})


In [None]:
# ==== Rebuild tokenized dataset + train text-only GPT-2 decoder (fixed) ====
from datasets import DatasetDict
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer
from collections import defaultdict
import math
import os

# 1) Rebuild tokenized_text_ds cleanly (no manual labels)
def tokenize_for_lm(batch):
    """Tokenize a batch of reports as "<BOS> text <EOS>" for causal-LM training.

    The tokenizer call does not append an end-of-sequence token by itself, so
    the EOS id is added here; without it the model never sees an EOS target
    and cannot learn where a report ends.

    Args:
        batch: mapping with a "report_text" list of strings.

    Returns:
        dict with unpadded "input_ids" and "attention_mask" lists, one entry
        per report, each at most TEXT_MAX_LENGTH tokens long. No labels are
        produced here.
    """
    eos_id = tokenizer.eos_token_id
    # Reserve one slot so truncation can never cut off the EOS marker.
    budget = TEXT_MAX_LENGTH - 1 if eos_id is not None else TEXT_MAX_LENGTH

    encoded = tokenizer(
        [tokenizer.bos_token + " " + t for t in batch["report_text"]],
        truncation=True,
        max_length=budget,
        return_attention_mask=True,
        padding=False,   # leave ragged; collator will pad per batch
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    if eos_id is not None:
        input_ids = [ids + [eos_id] for ids in input_ids]
        attention_mask = [mask + [1] for mask in attention_mask]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Re-tokenize from the raw text so no manually added labels column survives.
tokenized_text_ds = text_ds.map(
    tokenize_for_lm, batched=True, remove_columns=["report_text"]
)

print("Tokenized dataset:")
print(tokenized_text_ds)
# Show the first encoded example of every split as a sanity check.
for split, split_ds in tokenized_text_ds.items():
    print(split, split_ds[0])

# 2) Data collator: pads each batch and builds the labels from input_ids
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # mlm=False -> causal LM

# 3) Output dirs: sample runs and full runs are kept in separate folders
DECODER_TOKENIZER_DIR = os.path.join(
    BASE_DIR,
    "sample_decoder_tokenizer" if USE_SAMPLE_TEXT else "decoder_tokenizer",
)
DECODER_PRETRAINED_DIR = os.path.join(
    BASE_DIR,
    "sample_decoder_pretrained" if USE_SAMPLE_TEXT else "decoder_pretrained",
)

for _out_dir in (DECODER_TOKENIZER_DIR, DECODER_PRETRAINED_DIR):
    os.makedirs(_out_dir, exist_ok=True)

print("Decoder tokenizer dir:", DECODER_TOKENIZER_DIR)
print("Decoder model dir:", DECODER_PRETRAINED_DIR)

# 4) TrainingArguments + Trainer
training_args = TrainingArguments(
    output_dir=DECODER_PRETRAINED_DIR,
    num_train_epochs=TEXT_NUM_EPOCHS,
    learning_rate=TEXT_LR,
    per_device_train_batch_size=TEXT_TRAIN_BS,
    per_device_eval_batch_size=TEXT_EVAL_BS,
    gradient_accumulation_steps=TEXT_GRAD_ACCUM,
    warmup_steps=50,
    weight_decay=0.01,
    logging_steps=50,
    fp16=(device == "cuda"),  # mixed precision only when training on GPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_text_ds["train"],
    eval_dataset=tokenized_text_ds["validation"],
)

# 5) Train
trainer.train()

# 6) Evaluate on the validation split and turn the loss into a perplexity
eval_results = trainer.evaluate()
print(eval_results)

try:
    perplexity = math.exp(eval_results["eval_loss"])
except KeyError:
    print("No eval_loss found.")
else:
    print("Validation Perplexity (text-only decoder):", perplexity)

# 7) Epoch-level summary
epoch_train_losses = defaultdict(list)
epoch_eval_losses = defaultdict(list)

# Bucket every logged loss under the (possibly fractional) epoch value it was
# logged at; entries carrying an eval_loss count as evaluation records only.
for entry in trainer.state.log_history:
    if "epoch" not in entry:
        continue
    if "eval_loss" in entry:
        epoch_eval_losses[entry["epoch"]].append(entry["eval_loss"])
    elif "loss" in entry:
        epoch_train_losses[entry["epoch"]].append(entry["loss"])

print("\nEpoch | Train loss | Eval loss")
print("------|-----------|----------")
all_epochs = sorted(set(epoch_train_losses) | set(epoch_eval_losses))
for ep in all_epochs:
    # Show the most recent value recorded under this epoch key, if any.
    train_loss = None
    if ep in epoch_train_losses:
        train_loss = epoch_train_losses[ep][-1]
    eval_loss = None
    if ep in epoch_eval_losses:
        eval_loss = epoch_eval_losses[ep][-1]
    t_str = "   -   " if train_loss is None else f"{train_loss:.4f}"
    e_str = "   -   " if eval_loss is None else f"{eval_loss:.4f}"
    print(f"{ep:5.1f} | {t_str:9} | {e_str}")


Map:   0%|          | 0/5223 [00:00<?, ? examples/s]

Map:   0%|          | 0/1122 [00:00<?, ? examples/s]

Map:   0%|          | 0/1121 [00:00<?, ? examples/s]

Tokenized dataset:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 5223
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1122
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1121
    })
})
train {'input_ids': [50257, 220, 50260, 383, 26077, 41834, 290, 16957, 459, 259, 388, 2546, 389, 1626, 3487, 7095, 13, 1318, 318, 645, 45105, 220, 50267, 13, 1318, 318, 645, 25397, 220, 50264, 13, 1318, 389, 645, 286, 257, 220, 50275, 220, 50265, 13, 1318, 318, 645, 2370, 286, 220, 50266, 13, 220, 50261, 14435, 7721, 2124, 34507], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
validation {'input_ids': [50257, 220, 50260, 1318, 389, 42864, 24537, 220, 29446, 290, 435, 303, 6192, 220, 50270, 6414, 351, 10726, 265

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 50258, 'bos_token_id': 50257, 'pad_token_id': 50259}.
[34m[1mwandb[0m: Currently logged in as: [33mnir64[0m ([33mnir64-university-of-maryland[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
50,4.33
100,2.5431
150,1.9659
200,1.7206
250,1.5388
300,1.4944
350,1.3674
400,1.3083
450,1.2442
500,1.2534


{'eval_loss': 1.1988717317581177, 'eval_runtime': 2.9917, 'eval_samples_per_second': 375.039, 'eval_steps_per_second': 47.131, 'epoch': 5.0}
Validation Perplexity (text-only decoder): 3.316373052695281

Epoch | Train loss | Eval loss
------|-----------|----------
  0.3 | 4.3300    |    -   
  0.6 | 2.5431    |    -   
  0.9 | 1.9659    |    -   
  1.2 | 1.7206    |    -   
  1.5 | 1.5388    |    -   
  1.8 | 1.4944    |    -   
  2.1 | 1.3674    |    -   
  2.4 | 1.3083    |    -   
  2.7 | 1.2442    |    -   
  3.0 | 1.2534    |    -   
  3.4 | 1.1657    |    -   
  3.7 | 1.1690    |    -   
  4.0 | 1.1804    |    -   
  4.3 | 1.1465    |    -   
  4.6 | 1.1131    |    -   
  4.9 | 1.1211    |    -   
  5.0 |    -      | 1.1989


In [None]:
# ==== Save tokenizer + text-only decoder weights ====
# Both artifacts are reloaded with from_pretrained() when the multimodal
# decoder is built. The tokenizer has to be saved alongside the weights
# because its vocabulary was extended (special + medical tokens) and the
# model's embedding matrix was resized to match it.
tokenizer.save_pretrained(DECODER_TOKENIZER_DIR)
trainer.save_model(DECODER_PRETRAINED_DIR)

print("Saved tokenizer to:", DECODER_TOKENIZER_DIR)
print("Saved text decoder model to:", DECODER_PRETRAINED_DIR)


Saved tokenizer to: /content/drive/MyDrive/processed data/data/data/decoder_tokenizer
Saved text decoder model to: /content/drive/MyDrive/processed data/data/data/decoder_pretrained


# ...

In [None]:
# ==== Load multimodal_dataset_full.pt and inspect ====
# NOTE(review): torch.load is pickle-based. Only load this file from a trusted
# source; depending on the installed torch version, consider passing
# weights_only=True explicitly if the stored objects allow it.
# map_location="cpu" keeps the embeddings in host memory.
mm_data_full = torch.load(MM_FULL_PATH, map_location="cpu")

# Top-level keys are the splits plus a "metadata" entry (if present).
print("Multimodal full dataset splits:", mm_data_full.keys())
print("Metadata:", mm_data_full.get("metadata", {}))

# Peek at the first training sample to confirm the per-sample schema.
example = mm_data_full["train"][0]
print("\nExample train sample keys:", example.keys())
print("Image embedding shape:", example["image_emb"].shape)
print("Text embedding shape:", example["text_emb"].shape)
print("\nExample impression:", example["impression"])
print("\nExample full_report snippet:\n", example["full_report"][:300], "...")


Multimodal full dataset splits: dict_keys(['train', 'val', 'test', 'metadata'])
Metadata: {'creation_date': '2025-12-02 23:01:23', 'vision_model': 'google/vit-base-patch16-224', 'text_model': 'emilyalsentzer/Bio_ClinicalBERT', 'image_embedding_shape': [197, 768], 'text_embedding_shape': [768], 'splits': {'train': {'num_samples': 5223, 'image_embedding_shape': [197, 768], 'text_embedding_shape': [768]}, 'val': {'num_samples': 1122, 'image_embedding_shape': [197, 768], 'text_embedding_shape': [768]}, 'test': {'num_samples': 1121, 'image_embedding_shape': [197, 768], 'text_embedding_shape': [768]}}}

Example train sample keys: dict_keys(['filename', 'image_emb', 'text_emb', 'impression', 'full_report', 'mesh', 'problems', 'projection'])
Image embedding shape: torch.Size([197, 768])
Text embedding shape: torch.Size([768])

Example impression: Normal chest x-.

Example full_report snippet:
 Indication: Positive TB test Comparison: None. Findings: The cardiac silhouette and mediastinum size 

In [None]:
# ==== Build multimodal dataset (ViT + ClinicalBERT → condition vector) ====
from torch.utils.data import Dataset

class MultimodalReportDataset(Dataset):
    """Pairs each report's token ids with a fixed-size conditioning vector.

    Every item is "<BOS> report <EOS>" tokenized for causal-LM training plus
    the concatenation of the mean-pooled image embedding and the text
    embedding stored in the sample.

    Args:
        samples: indexable collection of dicts providing "image_emb",
            "text_emb" and a "full_report" and/or "impression" string.
        tokenizer: callable tokenizer exposing ``bos_token`` and
            ``eos_token_id`` and returning "input_ids" / "attention_mask".
        max_length: maximum number of tokens per item, EOS included.
    """

    def __init__(self, samples, tokenizer, max_length=256):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return input_ids, attention_mask, labels and cond_vec for sample ``idx``."""
        s = self.samples[idx]
        # Prefer full_report; fallback to impression if needed
        text = s.get("full_report", "") or s.get("impression", "") or ""
        text = str(text)

        prompt = (self.tokenizer.bos_token + " " + text).strip()

        # The tokenizer call does not append EOS by itself. Reserve one slot
        # for it so truncation cannot drop the end-of-sequence target;
        # otherwise the decoder never learns when to stop generating.
        eos_id = self.tokenizer.eos_token_id
        budget = self.max_length - 1 if eos_id is not None else self.max_length

        encoded = self.tokenizer(
            prompt,
            truncation=True,
            max_length=budget,
            return_tensors="pt"
        )

        input_ids = encoded["input_ids"][0]
        attention_mask = encoded["attention_mask"][0]
        if eos_id is not None:
            input_ids = torch.cat([input_ids, input_ids.new_tensor([eos_id])])
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(1)])
        # Causal LM: targets are the inputs themselves.
        labels = input_ids.clone()

        # Condition vector: mean ViT patch embedding + ClinicalBERT embedding
        # NOTE(review): if text_emb encodes the very report being predicted,
        # the decoder is conditioned on its own target - confirm what the
        # embedding was computed from.
        img_emb = s["image_emb"]              # (197, 768) in the inspected dataset
        txt_emb = s["text_emb"]               # (768,) in the inspected dataset
        img_pooled = img_emb.mean(dim=0)      # average over the first (patch) axis
        cond_vec = torch.cat([img_pooled, txt_emb], dim=-1)  # (1536,) for the shapes above

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "cond_vec": cond_vec,
        }

# Optional truncation for Colab; with MAX_TRAIN_MM / MAX_VAL_MM = None (e.g. on
# Zaratan) the stored sample lists are used untouched.
train_samples_mm = (
    mm_data_full["train"]
    if MAX_TRAIN_MM is None
    else mm_data_full["train"][:MAX_TRAIN_MM]
)
val_samples_mm = (
    mm_data_full["val"]
    if MAX_VAL_MM is None
    else mm_data_full["val"][:MAX_VAL_MM]
)

# Same tokenizer and token budget for both splits.
mm_train_ds, mm_val_ds = (
    MultimodalReportDataset(split_samples, tokenizer, max_length=MM_MAX_LENGTH)
    for split_samples in (train_samples_mm, val_samples_mm)
)

print("Multimodal train size:", len(mm_train_ds))
print("Multimodal val size:", len(mm_val_ds))


Multimodal train size: 5223
Multimodal val size: 1122


In [None]:
# ==== Collate function for multimodal batches ====
def multimodal_collate_fn(batch):
    """Collate dataset items into one right-padded batch.

    Args:
        batch: list of dicts with "input_ids", "attention_mask", "labels"
            (1-D tensors of per-item length) and "cond_vec".

    Returns:
        dict with the same four keys. Token ids are padded with the
        tokenizer's pad id, the attention mask with 0 and the labels with
        -100; the condition vectors are stacked along a new first axis.
    """
    def column(key):
        # Gather one field across all items, preserving batch order.
        return [item[key] for item in batch]

    return {
        "input_ids": pad_sequence(
            column("input_ids"), batch_first=True, padding_value=tokenizer.pad_token_id
        ),
        "attention_mask": pad_sequence(
            column("attention_mask"), batch_first=True, padding_value=0
        ),
        "labels": pad_sequence(
            column("labels"), batch_first=True, padding_value=-100
        ),
        "cond_vec": torch.stack(column("cond_vec"), dim=0),  # (B, 1536)
    }


In [None]:
# ==== Load text decoder checkpoint and wrap with conditioning ====
print("Device:", device)

# Recompute the directories with the same rule used when the text decoder was
# trained, so this cell also works in a fresh session.
DECODER_TOKENIZER_DIR = os.path.join(
    BASE_DIR,
    "sample_decoder_tokenizer" if USE_SAMPLE_TEXT else "decoder_tokenizer",
)
DECODER_PRETRAINED_DIR = os.path.join(
    BASE_DIR,
    "sample_decoder_pretrained" if USE_SAMPLE_TEXT else "decoder_pretrained",
)

print("Using decoder tokenizer from:", DECODER_TOKENIZER_DIR)
print("Using decoder weights from:", DECODER_PRETRAINED_DIR)

# Reload from disk rather than reusing in-memory objects.
tokenizer = AutoTokenizer.from_pretrained(DECODER_TOKENIZER_DIR)
gpt2_base = AutoModelForCausalLM.from_pretrained(DECODER_PRETRAINED_DIR)
gpt2_base = gpt2_base.to(device)

print("bos_token:", tokenizer.bos_token)
print("eos_token:", tokenizer.eos_token)
print("pad_token:", tokenizer.pad_token)
print("GPT-2 hidden size:", gpt2_base.config.n_embd)

class GPT2WithConditioning(nn.Module):
    """Causal LM wrapper that prepends one projected condition vector as a prefix.

    The condition vector is mapped to the decoder's hidden size with a linear
    layer and placed in front of the token embeddings as a single extra
    position. That position is always attended to and never scored.

    Args:
        base_model: language model exposing ``config.n_embd``,
            ``transformer.wte`` and accepting ``inputs_embeds``.
        cond_dim: size of the incoming condition vector.
    """

    def __init__(self, base_model, cond_dim=1536):
        super().__init__()
        self.gpt2 = base_model
        hidden_size = self.gpt2.config.n_embd
        self.cond_proj = nn.Linear(cond_dim, hidden_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, cond_vec=None):
        """Run the wrapped model on [condition prefix] + token embeddings.

        ``attention_mask`` and ``labels`` are optional; when given, each is
        extended by one leading column (1 for the mask, -100 for the labels)
        so it lines up with the prefixed sequence. Returns whatever the
        wrapped model returns.
        """
        prefix = self.cond_proj(cond_vec)[:, None, :]        # (B, 1, hidden)
        tokens = self.gpt2.transformer.wte(input_ids)        # (B, T, hidden)
        inputs_embeds = torch.cat((prefix, tokens), dim=1)   # (B, 1+T, hidden)

        if attention_mask is not None:
            # The prefix position is always visible.
            visible = attention_mask.new_ones((attention_mask.size(0), 1))
            attention_mask = torch.cat((visible, attention_mask), dim=1)

        if labels is not None:
            # -100 keeps the prefix position out of the loss.
            ignored = labels.new_full((labels.size(0), 1), -100)
            labels = torch.cat((ignored, labels), dim=1)

        return self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

mm_model = GPT2WithConditioning(gpt2_base, cond_dim=COND_DIM).to(device)
# The KV cache is only useful for incremental decoding, not for training.
mm_model.gpt2.config.use_cache = False

# Untie shared weights so safetensors can save without error
with torch.no_grad():
    mm_model.gpt2.lm_head.weight = nn.Parameter(
        mm_model.gpt2.lm_head.weight.clone()
    )
# Record the untying in the config as well. Otherwise a later
# from_pretrained() on the saved GPT-2 directory would tie lm_head back to the
# input embeddings and silently discard the separately fine-tuned head.
mm_model.gpt2.config.tie_word_embeddings = False

mm_model


Device: cuda
Using decoder tokenizer from: /content/drive/MyDrive/processed data/data/data/decoder_tokenizer
Using decoder weights from: /content/drive/MyDrive/processed data/data/data/decoder_pretrained
bos_token: <BOS>
eos_token: <EOS>
pad_token: <PAD>
GPT-2 hidden size: 768


GPT2WithConditioning(
  (gpt2): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50276, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_f

In [None]:
# ==== Train multimodal decoder (encoder-conditioned GPT-2) ====
from transformers import TrainingArguments, Trainer

# Sample runs and full runs write to different folders under BASE_DIR.
if USE_SAMPLE_TEXT:
    MM_OUTPUT_DIR = os.path.join(BASE_DIR, "sample_multimodal_decoder")
else:
    MM_OUTPUT_DIR = os.path.join(BASE_DIR, "full_multimodal_decoder")
os.makedirs(MM_OUTPUT_DIR, exist_ok=True)
print("Multimodal decoder output dir:", MM_OUTPUT_DIR)

mm_training_args = TrainingArguments(
    output_dir=MM_OUTPUT_DIR,
    num_train_epochs=MM_NUM_EPOCHS,
    learning_rate=MM_LR,
    per_device_train_batch_size=MM_TRAIN_BS,
    per_device_eval_batch_size=MM_EVAL_BS,
    gradient_accumulation_steps=MM_GRAD_ACCUM,
    warmup_steps=0,
    weight_decay=0.01,
    logging_steps=5,
    fp16=(device == "cuda"),  # mixed precision only when training on GPU
)

mm_trainer = Trainer(
    model=mm_model,
    args=mm_training_args,
    data_collator=multimodal_collate_fn,
    tokenizer=tokenizer,
    train_dataset=mm_train_ds,
    eval_dataset=mm_val_ds,
)

mm_trainer.train()

mm_eval_results = mm_trainer.evaluate()
print("Multimodal eval results:", mm_eval_results)

# Save three artifacts: the GPT-2 core, the tokenizer, and the state dict of
# the full wrapper (which also holds the conditioning projection).
MM_GPT2_DIR = os.path.join(MM_OUTPUT_DIR, "mm_decoder_gpt2")
MM_TOKENIZER_DIR = os.path.join(MM_OUTPUT_DIR, "mm_decoder_tokenizer")
os.makedirs(MM_GPT2_DIR, exist_ok=True)
os.makedirs(MM_TOKENIZER_DIR, exist_ok=True)

mm_model.gpt2.save_pretrained(MM_GPT2_DIR)
tokenizer.save_pretrained(MM_TOKENIZER_DIR)

torch.save(mm_model.state_dict(), os.path.join(MM_OUTPUT_DIR, "mm_model_state.pt"))

print("Saved GPT-2 core, tokenizer, and multimodal wrapper to:", MM_OUTPUT_DIR)


Multimodal decoder output dir: /content/drive/MyDrive/processed data/data/data/full_multimodal_decoder


  mm_trainer = Trainer(


Step,Training Loss
5,2.7421
10,2.2546
15,2.0608
20,1.9716
25,1.7707
30,1.6852
35,1.6495
40,1.6302
45,1.6115
50,1.5441


Multimodal eval results: {'eval_loss': 1.164597988128662, 'eval_runtime': 6.3049, 'eval_samples_per_second': 177.956, 'eval_steps_per_second': 44.568, 'epoch': 10.0}
Saved GPT-2 core, tokenizer, and multimodal wrapper to: /content/drive/MyDrive/processed data/data/data/full_multimodal_decoder


In [None]:
# ==== Multimodal eval: perplexity + summary ====
import pandas as pd

# Perplexity is exp(loss); it is left as None when no loss was reported.
if "eval_loss" not in mm_eval_results:
    mm_val_ppl = None
    print("No eval_loss found for multimodal model.")
else:
    mm_val_ppl = math.exp(mm_eval_results["eval_loss"])
    print("Multimodal validation perplexity:", mm_val_ppl)

# One-row table so the result renders nicely in the notebook.
mm_summary = dict(
    model=["ViT + ClinicalBERT → GPT-2 (conditioned)"],
    eval_loss=[mm_eval_results.get("eval_loss", None)],
    val_perplexity=[mm_val_ppl],
)
mm_summary_df = pd.DataFrame(mm_summary)
display(mm_summary_df)


Multimodal validation perplexity: 3.204634323170742


Unnamed: 0,model,eval_loss,val_perplexity
0,ViT + ClinicalBERT → GPT-2 (conditioned),1.164598,3.204634


In [None]:
# ==== Optional: Conditioned generation for sanity check ====
def generate_conditioned(idx=0, max_new_tokens=80):
    """Greedy-decode a report for one validation sample and print it.

    The sample's condition vector is built the same way as during training
    (mean-pooled image embedding concatenated with the text embedding),
    decoding starts from the BOS token, and the result is printed together
    with the stored impression and the start of the stored full report.

    Args:
        idx: index into ``mm_val_ds.samples``.
        max_new_tokens: upper bound on generated tokens; clamped so the
            sequence never exceeds the decoder's positional table.

    Returns:
        None. Output is printed. As a side effect ``mm_model`` is left in
        eval mode.
    """
    mm_model.eval()
    with torch.no_grad():
        s = mm_val_ds.samples[idx]

        img_pooled = s["image_emb"].mean(dim=0)
        cond_vec = torch.cat([img_pooled, s["text_emb"]], dim=-1).unsqueeze(0).to(device)
        # The prefix does not depend on the generated tokens: project it once
        # instead of once per decoding step.
        cond_emb = mm_model.cond_proj(cond_vec).unsqueeze(1)

        generated_ids = tokenizer(
            tokenizer.bos_token,
            return_tensors="pt"
        )["input_ids"].to(device)

        # Each step feeds 1 prefix position + all tokens so far; keep that
        # within the number of positions the decoder has embeddings for.
        max_positions = getattr(mm_model.gpt2.config, "n_positions", None)
        if max_positions is not None:
            max_new_tokens = min(max_new_tokens, max_positions - generated_ids.size(1))

        for _ in range(max_new_tokens):
            token_emb = mm_model.gpt2.transformer.wte(generated_ids)
            inputs_embeds = torch.cat([cond_emb, token_emb], dim=1)

            attn_mask = torch.ones(
                inputs_embeds.size()[:2],
                dtype=torch.long,
                device=device,
            )

            outputs = mm_model.gpt2(
                inputs_embeds=inputs_embeds,
                attention_mask=attn_mask,
            )
            # Greedy choice: most likely token at the last position.
            next_token_logits = outputs.logits[:, -1, :]
            next_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_id], dim=1)

            if tokenizer.eos_token_id is not None and next_id.item() == tokenizer.eos_token_id:
                break

        # Special tokens are kept so BOS/EOS are visible in the printout.
        gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)

        print("=== CONDITIONED GENERATION (val sample idx =", idx, ") ===\n")
        print(gen_text)
        print("\n---------------- Ground truth (for reference) ----------------")
        print("\n[Impression]\n", s["impression"])
        print("\n[Full report snippet]\n", s["full_report"][:400], "...")

# Smoke test: decode the first validation sample with up to 80 new tokens.
generate_conditioned(idx=0, max_new_tokens=80)


=== CONDITIONED GENERATION (val sample idx = 0 ) ===

<BOS> Indication: -year-old female with chest pain. Comparison: None. Findings: The cardiomediastinal silhouette is within normal limits for appearance. No focal areas of pulmonary consolidation. No pleural effusion. No pneumothorax. Minimal degenerative changes of the thoracic spine. Impression: 1. No acute intrathoracic abnormality. 2

---------------- Ground truth (for reference) ----------------

[Impression]
 1. Bullous emphysema and interstitial fibrosis. 2. Probably scarring in the left apex, although difficult to exclude a cavitary lesion. 3. Opacities in the bilateral upper lobes could represent scarring, however the absence of comparison exam, recommend short interval followup radiograph or CT thorax to document resolution.

[Full report snippet]
 Indication: -year-old with . Comparison: None available Findings: There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease