# UniXGen on Colab

### UniXGen Roadmap

| Component                | UniXGen (Original Paper)                      | Colab Reproduction                                                                                                                           |
|:------------------------:|:----------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------|
| **Data**                 | MIMIC‑CXR‑JPG images + reports                 | ✔️ **Filtered metadata** (strict image–report matching) with **view‑counts** & **unique‑view summary**<br/>✔️ **Single‑view** & **Multi‑view** CSVs for ablation |
| **Tokenizer**            | ByteLevelBPETokenizer (BBPE)                   | Unchanged                                                                                                                                     |
| **Image Encoder**        | VQGAN                                           | Unchanged                                                                                                                                     |
| **Text Encoder**         | Transformer‑based BBPE embedding               | Unchanged                                                                                                                                     |
| **Fusion Module**        | Joint Transformer encoder block                | Unchanged                                                                                                                                     |
| **Loss Function**        | Contrastive + autoregressive                   | Unchanged                                                                                                                                     |
| **Decoder**              | Transformer decoder (report + image)           | Unchanged                                                                                                                                     |
| **Training Framework**   | PyTorch Lightning                              | ✔️ **on_test_epoch_end** callback to compute BLEU automatically<br/>✔️ Test outputs saved as `.pt`                                              |
| **Generation**           | Autoregressive sampling                        | ✔️ **Top‑p (=0.9) + temperature (=0.7)** sampling (configurable)<br/>✔️ `--random_mode_order` flag (fixed vs. random ordering) for ablation        |
| **Ablation Sweeper**     | —                                              | ✔️ `unified_run_ablation.py`: grid sweep over `under_sample`, `max_img_num`, `target_count`                                                    |
| **Token‑Order Ablation** | —                                              | ✔️ `unified_run.py` + `--random_mode_order` (`True` vs. `False`) via `setup_modes()`                                                           |
| **Evaluation Metrics**   | BLEU (text), FID (image), 14‑disease AUROC (CheXpert)                       | ✔️ **BLEU**, **BERTScore** (P/R/F1),<br/>✔️ Per‑sample & summary CSV exports                           |
| **Checkpoints**          | Pretrained VQGAN + UniXGen                     | ✔️ Load original Lightning CKPT for inference<br/>✔️ Fully configurable via CLI (e.g. `--max_img_num`, `--target_count`, `--random_mode_order`) |
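
The Generation row lists top-p (0.9) plus temperature (0.7) sampling. As a rough illustration of what that combination does, here is a minimal, dependency-free sketch. It is not the repository's implementation; the function name `sample_top_p` and its signature are assumptions made for this example.

```python
import math
import random

def sample_top_p(logits, top_p=0.9, temperature=0.7, rng=random):
    """Sample one token index using temperature scaling and nucleus (top-p) filtering."""
    # Temperature scaling: values below 1 sharpen the distribution.
    scaled = [x / temperature for x in logits]
    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Draw among the kept tokens; random.choices normalises the weights itself.
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]

# Hypothetical logits for a four-token vocabulary.
print(sample_top_p([2.0, 1.5, 0.1, -3.0]))
```

Lowering `top_p` shrinks the candidate set, while lowering `temperature` concentrates probability on the highest-scoring tokens before the cut is applied.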


## Setup


In [None]:
from google.colab import drive
drive.mount('/content/drive')

## GitHub Repo Clone

In [None]:
!git clone https://github.com/slyncrafty/DL4H-PRJ-Group.git UniXGen

## Project Location

In [None]:
PRJ_ROOT = '/content/drive/MyDrive/UniXGen'  # Set this to the project's location on your Drive
%cd {PRJ_ROOT}

## Directory Setup & Download Files

In [None]:
!mkdir -p data/images     # Place MIMIC-CXR-JPG images
!mkdir -p data/reports    # Place MIMIC-CXR Database reports
!mkdir -p mimiccxr_vqgan  # Place Chest X-ray Tokenizer
!mkdir -p ckpt            # Place .ckpt model file
!mkdir -p output/decoded_images
!mkdir -p output/decoded_reports

🔧 Download MIMIC-CXR-JPG images & reports

- You must be a credentialed PhysioNet user to access the data.
- Download the chest X-rays from [MIMIC-CXR-JPG](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and place the images under **data/images/**
- Download the reports from [MIMIC-CXR Database](https://physionet.org/content/mimic-cxr/2.0.0/) and place them under **data/reports/**

🔧 Download VQGAN Tokenizer

- Download the [Chest X-ray Tokenizer (VQGAN)](https://drive.google.com/drive/folders/1Ia_GqRrmZ8g6md02TC5_nkrGn6eUwVaG) and place it under **mimiccxr_vqgan/**

🔧 Place model file under **/ckpt**

- Download the [updated UniXGen model](https://drive.google.com/file/d/1LuZXq7DpQUV9cgWTLK6SRvlmSHu_a5E1/view?usp=drive_link) (`unixgen_lightning.ckpt`) and place it under **ckpt/**

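🔑 Replace the file `./taming-transformers/taming/data/utils.py` with [utils.py](https://drive.google.com/file/d/1NCO8hojet42JdrgX1vKV3uMPpCLBWDw8/view?usp=drive_link). The `taming-transformers` folder is created by the clone step under Installations, so do this after that step has run.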

## Installations

### Installing Required Packages / Libraries

In [None]:
%cd {PRJ_ROOT}
%pip install --upgrade pip
%pip install -r requirements.txt
%pip install pytorch-lightning==2.0.9 --force-reinstall
%pip install --force-reinstall torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118
%pip uninstall -y numpy
%pip install numpy==1.24.4
%pip uninstall -y jax jaxlib
%pip install --upgrade jax==0.4.23 jaxlib==0.4.23

In [None]:
%cd {PRJ_ROOT}
!git clone https://github.com/CompVis/taming-transformers.git
%cd taming-transformers
!pip install -e .
%cd {PRJ_ROOT}
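
Once the clone exists, the `utils.py` replacement described under Directory Setup can be scripted. The helper below is a sketch under assumptions: `replace_with_backup` is a name invented here, and the download location in the commented call is hypothetical.

```python
import shutil
from pathlib import Path

def replace_with_backup(new_file, target):
    """Copy `new_file` over `target`, keeping the original as `<name>.bak`."""
    new_file, target = Path(new_file), Path(target)
    if not new_file.is_file():
        raise FileNotFoundError(f"replacement not found: {new_file}")
    if target.exists():
        # Keep the upstream version next to the replaced file for reference.
        backup = target.with_name(target.name + ".bak")
        shutil.copy2(target, backup)
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(new_file, target)
    return target

# Hypothetical download location; adjust to wherever utils.py was saved.
# replace_with_backup("/content/utils.py", "taming-transformers/taming/data/utils.py")
```

Keeping a `.bak` copy makes it easy to compare the patched file with the upstream one if an import error appears later.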

### Check Libraries

In [None]:
import numpy as np;        print("numpy",     np.__version__)
import scipy;              print("scipy",     scipy.__version__)
import torch;              print("torch",     torch.__version__, "cuda", torch.version.cuda)
import torchvision;        print("torchvision", torchvision.__version__)
import pytorch_lightning as pl; print("PL",       pl.__version__)
import torchmetrics;       print("torchmetrics", torchmetrics.__version__)
import transformers;       print("transformers", transformers.__version__)
import omegaconf;          print("omegaconf", omegaconf.__version__)

numpy 1.24.4
scipy 1.10.1
torch 2.7.0+cu126 cuda 12.6
torchvision 0.15.2+cu118
PL 2.0.9
torchmetrics 1.7.1
transformers 4.37.2


In [None]:
import jax;                print("jax",       jax.__version__)
import jaxlib;             print("jaxlib",    jaxlib.__version__)

jax 0.4.23
jaxlib 0.4.23
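
The recorded output above shows `torch 2.7.0+cu126` although the install cell pins `torch==2.0.1+cu118`, so it can help to compare installed versions against the pins automatically. The sketch below uses only the standard library; the pins are copied from the install cell, and the helper names `installed_versions` and `find_mismatches` are assumptions made for this example.

```python
from importlib import metadata

# Pins copied from the install cell above.
PINNED = {
    "pytorch-lightning": "2.0.9",
    "torch": "2.0.1+cu118",
    "torchvision": "0.15.2+cu118",
    "torchaudio": "2.0.2+cu118",
    "numpy": "1.24.4",
    "jax": "0.4.23",
    "jaxlib": "0.4.23",
}

def installed_versions(names):
    """Return {name: version or None} for each distribution name."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found

def find_mismatches(pinned, installed):
    """List (name, expected, actual) for every package that differs from its pin."""
    return [
        (name, want, installed.get(name))
        for name, want in pinned.items()
        if installed.get(name) != want
    ]

for name, want, got in find_mismatches(PINNED, installed_versions(PINNED)):
    print(f"{name}: expected {want}, found {got}")
```

An empty result means every pinned package matches. If packages were reinstalled in the current session, restarting the Colab runtime before running the check ensures the imported modules match what is on disk.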


### File location check (Optional)

In [None]:
import os
import torch
import pandas as pd
from tokenizers import ByteLevelBPETokenizer

### Adjust these paths to match your config
ckpt_path = 'ckpt/unixgen_lightning.ckpt'
vocab_path = 'BBPE_tokenizer/vocab.json'
merges_path = 'BBPE_tokenizer/merges.txt'
test_meta_file = 'metadata/mimiccxr_test_filtered.csv'
output_dir = 'output'

print("🔍 Checking paths and files...")

# Check checkpoint
if os.path.isfile(ckpt_path):
    print(f"✅ Checkpoint file found: {ckpt_path}")
    try:
        ckpt = torch.load(ckpt_path, map_location='cpu')
        print(f"✅ Checkpoint loaded successfully, keys: {list(ckpt.keys())}")
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
else:
    print(f"❌ Checkpoint file NOT found: {ckpt_path}")

# Check tokenizer files
if os.path.isfile(vocab_path) and os.path.isfile(merges_path):
    print(f"✅ Tokenizer vocab and merges found.")
    try:
        tokenizer = ByteLevelBPETokenizer(vocab_path, merges_path)
        tokenizer.add_special_tokens(["[PAD]", "[SOS]", "[EOS]", "[SEP]", "[MASK]"])
        print("✅ Tokenizer loaded and special tokens added.")
    except Exception as e:
        print(f"❌ Failed to load tokenizer: {e}")
else:
    print(f"❌ Missing vocab or merges file: {vocab_path}, {merges_path}")

# Check test metadata
if os.path.isfile(test_meta_file):
    print(f"✅ Test metadata file found: {test_meta_file}")
    try:
        df = pd.read_csv(test_meta_file)
        if df.empty:
            print("⚠️ Test metadata CSV is EMPTY.")
        else:
            print(f"✅ Test metadata CSV loaded with {len(df)} rows.")
    except Exception as e:
        print(f"❌ Failed to load CSV: {e}")
else:
    print(f"❌ Test metadata file NOT found: {test_meta_file}")

# Check output directory
if os.path.isdir(output_dir):
    print(f"✅ Output directory exists: {output_dir}")
else:
    try:
        os.makedirs(output_dir, exist_ok=True)
        print(f"✅ Output directory created: {output_dir}")
    except Exception as e:
        print(f"❌ Failed to create output directory: {e}")

print("\n🛡 Validation complete.")


🔍 Checking paths and files...
✅ Checkpoint file found: ckpt/unixgen_lightning.ckpt
✅ Checkpoint loaded successfully, keys: ['pytorch-lightning_version', 'state_dict', 'callbacks', 'hparams_name', 'hyper_parameters', 'global_step', 'epoch', 'loops', 'legacy_pytorch-lightning_version']
✅ Tokenizer vocab and merges found.
✅ Tokenizer loaded and special tokens added.
✅ Test metadata file found: metadata/mimiccxr_test_filtered.csv
✅ Test metadata CSV loaded with 159 rows.
✅ Output directory exists: output

🛡 Validation complete.


## Pre-processing

#### Create Filtered metadata

In [None]:
%cd {PRJ_ROOT}
!pwd

/content/drive/MyDrive/UniXGen
/content/drive/MyDrive/UniXGen


In [None]:
"""
Create a filtered metadata .csv file from image data available in the location.
Add a column for unique view counts
"""

import os
from glob import glob
import pandas as pd

def generate_filtered_metadata_with_summary(
    original_csv: str,
    image_dir:      str,
    detailed_output_csv: str,
    summary_output_csv:  str,
    missing_folder_report:   str = None,
    missing_metadata_report: str = None
):
    """
    1) Read raw metadata (no headers) → columns: dicom_id, subject_id, study_id, view, count
    2) Find all study folders under `image_dir` that contain at least one .jpg
    3) Report studies in metadata but no images, and vice versa (optional CSVs)
    4) Write filtered detailed metadata (5 columns, no extras)
    5) Write summary per-study CSV with total_images_in_study & unique_views_in_study
    """
    # ——— Load original metadata ——————————————————————————————————————
    with open(original_csv, 'r') as f:
        first = f.readline()
    sep = '\t' if '\t' in first else ','
    cols = ['dicom_id','subject_id','study_id','view','count']
    df = pd.read_csv(original_csv, sep=sep, header=None, names=cols, dtype=str)
    meta_studies = set(df['study_id'].unique())
    print(f"✅ Loaded original metadata: {len(df)} rows across {len(meta_studies)} studies.")

    # ——— Scan image_dir for valid study folders —————————————————————
    jpg_paths = glob(os.path.join(image_dir, '**', '*.jpg'), recursive=True)
    img_studies = set()
    for p in jpg_paths:
        folder = os.path.basename(os.path.dirname(p))
        if folder.lower().startswith('s'):  # e.g. s50051329
            img_studies.add(folder[1:])
    print(f"✅ Found {len(img_studies)} study folders with .jpg files in `{image_dir}`.")

    # ——— Missing‐in‐folder / Missing‐in‐metadata reports —————————————
    missing_folder   = sorted(meta_studies - img_studies)
    missing_metadata = sorted(img_studies - meta_studies)
    if missing_folder:
        print(f"⚠️ {len(missing_folder)} studies in metadata but no images.")
        if missing_folder_report:
            pd.DataFrame(missing_folder, columns=['study_id'])\
              .to_csv(missing_folder_report, index=False)
            print(f"  → Saved to {missing_folder_report}")
    if missing_metadata:
        print(f"⚠️ {len(missing_metadata)} image folders with no metadata.")
        if missing_metadata_report:
            pd.DataFrame(missing_metadata, columns=['study_id'])\
              .to_csv(missing_metadata_report, index=False)
            print(f"  → Saved to {missing_metadata_report}")

    # ——— Filter metadata to only those studies with images ———————————
    filtered = df[df['study_id'].isin(img_studies)].copy()
    print(f"✅ Filtered metadata: {len(filtered)} rows across {filtered['study_id'].nunique()} studies.")
    filtered.to_csv(detailed_output_csv, index=False, header=False)
    print(f"  → Wrote filtered detailed metadata to `{detailed_output_csv}`")

    # ——— Build & write per-study summary —————————————————————————
    summary = (
        filtered
        .groupby('study_id')
        .agg(
            total_images_in_study=('dicom_id','count'),
            unique_views_in_study=('view', pd.Series.nunique)
        )
        .reset_index()
    )
    summary.to_csv(summary_output_csv, index=False)
    print(f"✅ Wrote summary metadata to `{summary_output_csv}`")
    print("🎯 All outputs and diagnostics complete.\n")


def generate_experiment_csvs(
    summary_csv:       str,
    full_filtered_csv: str,
    image_root_dir:    str,
    output_dir:        str
):
    """
    1) Load summary metadata (must have columns: study_id, total_images_in_study, unique_views_in_study)
    2) Load filtered detailed metadata (5 cols: dicom_id, subject_id, study_id, view, count)
    3) Re-scan `image_root_dir` to enforce only studies with .jpg
    4) Split into single-view (unique_views==1) vs multi-view (unique_views>=2)
    5) Write single_view.csv & multi_view.csv (no headers)
    """
    # ——— Load inputs —————————————————————————————————————————
    summary_df = pd.read_csv(summary_csv, dtype={'study_id':str})
    detail_df  = pd.read_csv(full_filtered_csv, header=None, dtype=str)
    if detail_df.shape[1] != 5:
        raise ValueError(f"Expected 5 columns in `{full_filtered_csv}`, got {detail_df.shape[1]}")
    detail_df.columns = ['dicom_id','subject_id','study_id','view','count']

    # ——— Re-scan images to ensure only valid studies —————————————
    jpgs = glob(os.path.join(image_root_dir, '**', '*.jpg'), recursive=True)
    valid = set()
    for p in jpgs:
        fld = os.path.basename(os.path.dirname(p))
        if fld.lower().startswith('s'):
            valid.add(fld[1:])
    summary_df = summary_df[summary_df['study_id'].isin(valid)]
    detail_df  = detail_df[ detail_df['study_id'].isin(valid) ]

    # ——— Single-view split —————————————————————————————————————
    single_ids = summary_df.query("unique_views_in_study == 1")['study_id']
    single_df  = detail_df[ detail_df['study_id'].isin(single_ids) ]
    sv_path    = os.path.join(output_dir, 'single_view.csv')
    single_df.to_csv(sv_path, index=False, header=False)
    print(f"✅ Saved single-view CSV: `{sv_path}` ({len(single_df)} rows)")

    # ——— Multi-view split ——————————————————————————————————————
    multi_ids = summary_df.query("unique_views_in_study >= 2")['study_id']
    multi_df  = detail_df[ detail_df['study_id'].isin(multi_ids) ]
    mv_path   = os.path.join(output_dir, 'multi_view.csv')
    multi_df.to_csv(mv_path, index=False, header=False)
    print(f"✅ Saved multi-view CSV: `{mv_path}` ({len(multi_df)} rows)")

    print("🎯 Experiment CSV generation complete.\n")


In [None]:
# 1) Filter + summary
generate_filtered_metadata_with_summary(
    original_csv='metadata/mimiccxr_test_sub_final.csv',
    image_dir='data/images',
    detailed_output_csv='metadata/mimiccxr_test_filtered.csv',
    summary_output_csv='metadata/mimiccxr_test_summary.csv',
    missing_folder_report='metadata/missing_in_folder.csv',
    missing_metadata_report='metadata/missing_in_metadata.csv',
)

# 2) Produce single-/multi-view experiment lists
generate_experiment_csvs(
    summary_csv='metadata/mimiccxr_test_summary.csv',
    full_filtered_csv='metadata/mimiccxr_test_filtered.csv',
    image_root_dir='data/images',
    output_dir='metadata',
)


✅ Loaded original metadata: 4444 rows across 2799 studies.
✅ Found 427 study folders with .jpg files in `data/images`.
⚠️ 2438 studies in metadata but no images.
  → Saved to metadata/missing_in_folder.csv
⚠️ 66 image folders with no metadata.
  → Saved to metadata/missing_in_metadata.csv
✅ Filtered metadata: 526 rows across 361 studies.
  → Wrote filtered detailed metadata to `metadata/mimiccxr_test_filtered.csv`
✅ Wrote summary metadata to `metadata/mimiccxr_test_summary.csv`
🎯 All outputs and diagnostics complete.

✅ Saved single-view CSV: `metadata/single_view.csv` (277 rows)
✅ Saved multi-view CSV: `metadata/multi_view.csv` (249 rows)
🎯 Experiment CSV generation complete.
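
The split above is driven by the number of unique views per study, not by the number of images. The toy example below illustrates that distinction with the standard library only; the rows are made up, and `split_by_unique_views` is a hypothetical stand-in for the pandas logic in `generate_experiment_csvs`.

```python
from collections import defaultdict

# Made-up rows in the filtered-metadata layout:
# dicom_id, subject_id, study_id, view, count
rows = [
    ("d1", "p1", "100", "PA", "2"),
    ("d2", "p1", "100", "LATERAL", "2"),
    ("d3", "p2", "200", "AP", "2"),
    ("d4", "p2", "200", "AP", "2"),
    ("d5", "p3", "300", "PA", "1"),
]

def split_by_unique_views(rows):
    """Split rows into (single_view, multi_view) by unique views per study."""
    views = defaultdict(set)
    for _, _, study_id, view, _ in rows:
        views[study_id].add(view)
    single = [r for r in rows if len(views[r[2]]) == 1]
    multi = [r for r in rows if len(views[r[2]]) >= 2]
    return single, multi

single, multi = split_by_unique_views(rows)
print("single-view rows:", [r[0] for r in single])
print("multi-view rows:", [r[0] for r in multi])
```

Study `200` has two images but only one distinct view, so both of its rows land in the single-view list. Every row falls into exactly one list, which is consistent with the 277 + 249 = 526 rows reported in the output above.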



### Fix ckpt version for compatibility (Optional)

The original checkpoint was trained with an older library version, so updating it before running is recommended. The scripts below perform the update. An already-updated checkpoint, [unixgen_lightning.ckpt](https://drive.google.com/file/d/1LuZXq7DpQUV9cgWTLK6SRvlmSHu_a5E1/view?usp=drive_link), is also provided; place it in `ckpt/`.

In [None]:
!python fix_unixgen_ckpt.py \
       --in_ckpt  ckpt/unixgen.ckpt \
       --out_ckpt ckpt/unixgen_lightning.ckpt

🔧  Loading  ckpt/unixgen.ckpt
    ✔ renamed transformerLM_unified.image_pos_emb.weights_0 → transformerLM_unified.image_pos_emb.weights.0
    ✔ renamed transformerLM_unified.image_pos_emb.weights_1 → transformerLM_unified.image_pos_emb.weights.1

📋 Hyper-parameters stored in ckpt:
{'img_vocab_size': 1024,
 'max_img_num': 3,
 'max_seq_len': 3334,
 'num_img_tokens': 1035,
 'target_count': 3}

✅  Saved Lightning-compatible ckpt →  ckpt/unixgen_lightning.ckpt
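
The log above reports keys ending in `.weights_0` and `.weights_1` being renamed to `.weights.0` and `.weights.1`. A minimal sketch of that kind of key rewrite on a plain dictionary follows. It is not the code of `fix_unixgen_ckpt.py`; the function name `rename_indexed_weights` and the regular expression are assumptions made for illustration.

```python
import re

# Matches a trailing ".weights_<n>" so it can be rewritten as ".weights.<n>".
_INDEXED_SUFFIX = re.compile(r"\.weights_(\d+)$")

def rename_indexed_weights(state_dict):
    """Return a copy with keys ending in `.weights_<n>` rewritten to `.weights.<n>`."""
    renamed = {}
    for key, value in state_dict.items():
        renamed[_INDEXED_SUFFIX.sub(r".weights.\1", key)] = value
    return renamed

# Placeholder values stand in for tensors.
example = {
    "transformerLM_unified.image_pos_emb.weights_0": "tensor-a",
    "transformerLM_unified.image_pos_emb.weights_1": "tensor-b",
    "transformerLM_unified.some_other.weight": "tensor-c",
}
for key in rename_indexed_weights(example):
    print(key)
```

Keys that do not match the pattern pass through unchanged, so the same function can be applied to a whole `state_dict`.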


In [None]:
!python -m pytorch_lightning.utilities.upgrade_checkpoint ckpt/unixgen_lightning.ckpt

  rank_zero_warn(
100% 1/1 [00:01<00:00,  1.67s/it]


## Sanity Check

### Tokenizer Check

In [None]:
# Verify that the tokenizer loads correctly and round-trips text.
import json, os
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing

VOCAB_PATH  = "BBPE_tokenizer/vocab.json"
MERGE_PATH  = "BBPE_tokenizer/merges.txt"
CKPT_PATH   = "ckpt/unixgen.ckpt"

assert os.path.isfile(VOCAB_PATH) and os.path.isfile(MERGE_PATH), "Tokenizer files missing!"
assert os.path.isfile(CKPT_PATH), "Checkpoint file not found!"

# ──── load tokenizer exactly as authors did ───────────────────────────────
tok = ByteLevelBPETokenizer(VOCAB_PATH, MERGE_PATH)
tok.add_special_tokens(["[PAD]", "[SOS]", "[EOS]", "[SEP]", "[MASK]"])
tok._tokenizer.post_processor = BertProcessing(
    ("[EOS]", tok.token_to_id("[EOS]")),
    ("[SOS]", tok.token_to_id("[SOS]")),
)

print("📦  Vocab size:", len(tok.get_vocab()))
print("ID of [PAD] :", tok.token_to_id("[PAD]"))
print("ID of [SOS] :", tok.token_to_id("[SOS]"))
print("ID of [EOS] :", tok.token_to_id("[EOS]"))
print("ID of '.'   :", tok.token_to_id("."))        # easy sanity anchor

# ──── quick decode check ──────────────────────────────────────────────────
sample_ids = list(range(40))     # first 40 token IDs
decoded = tok.decode(sample_ids, skip_special_tokens=True)
print("\n🧪  First 40 IDs decode to:\n", decoded[:200], "…\n")  # truncate print

# ──── quick encode / round-trip check ─────────────────────────────────────
dummy = "No pneumothorax. The heart is mildly enlarged."
ids   = tok.encode(dummy).ids
roundtrip = tok.decode(ids, skip_special_tokens=True)
print("Round-trip ok? ->", roundtrip == dummy)


📦  Vocab size: 14526
ID of [PAD] : 0
ID of [SOS] : 1
ID of [EOS] : 2
ID of '.'   : 18

🧪  First 40 IDs decode to:
 !"#$%&'()*+,-./0123456789:;<=>?@ABC …

Round-trip ok? -> True


### Memory Check

In [None]:
import torch
torch.set_float32_matmul_precision('high')

In [None]:
!nvidia-smi

Mon May  5 16:07:59 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   36C    P8             11W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [None]:
import torch
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9} GB")
print(f"Memory reserved: {torch.cuda.memory_reserved() / 1e9} GB")
print(torch.cuda.memory_summary())  # memory_summary() returns a string; print it for readable output


Memory allocated: 0.0 GB
Memory reserved: 0.0 GB




In [None]:
!nproc

12


## Run

### For Training

In [None]:
# %cd {PRJ_ROOT}
# !python unified_main.py --batch_size 10 --num_workers 8

### Test

In [None]:
# Check stages
%cd {PRJ_ROOT}
!python unified_run.py

### Ablation

In [None]:
%cd {PRJ_ROOT}
!python unified_run_ablation.py
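
The roadmap describes `unified_run_ablation.py` as a grid sweep over `under_sample`, `max_img_num`, and `target_count`. The sketch below shows how such a grid can be expanded into individual runs. It is not the script's actual code: the value lists are placeholders, the helper names are invented here, and the `--under_sample` flag spelling is an assumption (only `--max_img_num` and `--target_count` are named as CLI flags in the roadmap).

```python
from itertools import product

# Placeholder values for illustration only; the real sweep ranges live in the script.
GRID = {
    "under_sample": ["option_a", "option_b"],
    "max_img_num": [1, 2, 3],
    "target_count": [1, 2, 3],
}

def expand_grid(grid):
    """Return one dict per combination of the grid's values."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]

def to_cli_args(config):
    """Render a configuration as `--name value` pairs."""
    args = []
    for name, value in config.items():
        args += [f"--{name}", str(value)]
    return args

configs = expand_grid(GRID)
print(len(configs), "configurations")
print(" ".join(["python", "unified_run.py", *to_cli_args(configs[0])]))
```

The number of runs is the product of the list lengths, so each added value multiplies the total runtime of the sweep.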

### Decode

#### Quick Preview output

In [None]:

import torch
from tokenizers import ByteLevelBPETokenizer

# Load the saved test outputs
test_outputs = torch.load('output/<--Your--File--Name.pt-->', map_location='cpu')

# Check how many
print(f"✅ Loaded {len(test_outputs)} test outputs.")

# Build the tokenizer once, outside the preview loop
tokenizer = ByteLevelBPETokenizer('BBPE_tokenizer/vocab.json', 'BBPE_tokenizer/merges.txt')
tokenizer.add_special_tokens(["[PAD]", "[SOS]", "[EOS]", "[SEP]", "[MASK]"])

# Preview the first 3 reports
for i in range(min(3, len(test_outputs))):
    gt_tokens = test_outputs[i]['GT_text']
    gen_tokens = test_outputs[i]['gen_text']

    # Decode the first sequence of each batch back to text
    gt_report = tokenizer.decode(gt_tokens[0].tolist(), skip_special_tokens=True)
    gen_report = tokenizer.decode(gen_tokens[0].tolist(), skip_special_tokens=True)

    print(f"\n=== Sample {i+1} ===")
    print(f"Ground Truth:\n{gt_report}")
    print(f"Generated:\n{gen_report}")


#### Decode Images

In [None]:
!python decode_cxr.py \
  --test_output_glob="./output/test_output_*.pt" \
  --save_dir="./output/decoded_images/" \
  --vqgan_model_path="./mimiccxr_vqgan/last.ckpt" \
  --vqgan_config_path="./mimiccxr_vqgan/2021-12-17T08-58-54-project.yaml" \
  --img_save=True \
  --preview=True


#### Decode Reports

In [None]:
!python decode_report.py \
    --test_output_dir ./output \
    --tokenizer_dir BBPE_tokenizer \
    --save_csv \
    --save_dir ./output/decoded_reports


### Metrics

#### FID

In [None]:
!python fid.py --gt_path ./output/decoded_images \
              --batch-size 32 --dims 1024 --num-workers 8


#### BLEU / BERTScore

In [None]:
!python evaluate_outputs.py \
    --decoded_glob "./output/decoded_reports/*_GT_vs_GEN.csv" \
    --output_csv "./output/eval_summary.csv"