diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..df3ff2164 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -238,6 +238,7 @@ Available Datasets datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset datasets/pyhealth.datasets.ChestXray14Dataset + datasets/pyhealth.datasets.VQARADDataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset datasets/pyhealth.datasets.ClinVarDataset diff --git a/docs/api/datasets/pyhealth.datasets.VQARADDataset.rst b/docs/api/datasets/pyhealth.datasets.VQARADDataset.rst new file mode 100644 index 000000000..d38986dc5 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.VQARADDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.VQARADDataset +=================================== + +The VQA-RAD dataset for medical visual question answering. The dataset loader +converts the public JSON annotations into a flat metadata CSV that PyHealth can +ingest, and its default task is :class:`~pyhealth.tasks.MedicalVQATask`. + +.. autoclass:: pyhealth.datasets.VQARADDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models/pyhealth.models.MedFlamingo.rst b/docs/api/models/pyhealth.models.MedFlamingo.rst index 7f782d0e3..a0f2475d9 100644 --- a/docs/api/models/pyhealth.models.MedFlamingo.rst +++ b/docs/api/models/pyhealth.models.MedFlamingo.rst @@ -3,21 +3,37 @@ pyhealth.models.MedFlamingo MedFlamingo: multimodal medical few-shot learner. -The separate callable MedFlamingoLayer (gated cross-attention dense block) -and the complete MedFlamingo model. +This reference covers the visual resampler, the gated cross-attention +building block, and the complete MedFlamingo model used in the VQA-RAD +integration branch. **Paper:** Moor et al. "Med-Flamingo: a Multimodal Medical Few-shot Learner" ML4H 2023. .. note:: - This is a stub implementation. 
The class structure and signatures are - in place, but ``forward()`` and ``generate()`` raise ``NotImplementedError``. + ``forward()`` follows the PyHealth training contract for dataset-backed + classification-style use, while ``generate()`` provides the multimodal + prompting path for direct medical VQA generation. -.. autoclass:: pyhealth.models.MedFlamingoLayer +PerceiverResampler +------------------ + +.. autoclass:: pyhealth.models.medflamingo.PerceiverResampler :members: :undoc-members: :show-inheritance: +MedFlamingoLayer +---------------- + +.. autoclass:: pyhealth.models.medflamingo.MedFlamingoLayer + :members: + :undoc-members: + :show-inheritance: + +MedFlamingo +----------- + .. autoclass:: pyhealth.models.MedFlamingo :members: :undoc-members: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..b1aaf74fd 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -213,6 +213,7 @@ Available Tasks DKA Prediction (MIMIC-IV) Drug Recommendation Length of Stay Prediction + Medical VQA Medical Transcriptions Classification Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.MedicalVQATask.rst b/docs/api/tasks/pyhealth.tasks.MedicalVQATask.rst new file mode 100644 index 000000000..4221d6ab3 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.MedicalVQATask.rst @@ -0,0 +1,12 @@ +pyhealth.tasks.MedicalVQATask +=================================== + +Medical visual question answering task for paired radiology images and +questions. This task treats VQA-RAD answers as a multiclass prediction target +so the resulting ``SampleDataset`` can be trained with the standard PyHealth +trainer loop. + +.. 
autoclass:: pyhealth.tasks.MedicalVQATask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/vqarad_medvqa_medflamingo.py b/examples/vqarad_medvqa_medflamingo.py new file mode 100644 index 000000000..2ff4d4b4a --- /dev/null +++ b/examples/vqarad_medvqa_medflamingo.py @@ -0,0 +1,375 @@ +"""End-to-end VQA-RAD MedFlamingo pipeline with ablation study. + +This script demonstrates the complete PyHealth pipeline for the MedFlamingo +model on the VQA-RAD medical visual question answering dataset: + +1. Load the VQA-RAD base dataset +2. Apply ``MedicalVQATask`` via ``set_task()`` +3. Split into train / validation / test sets +4. Create dataloaders +5. Train ``MedFlamingo`` with ``Trainer.train()`` +6. Evaluate with ``Trainer.evaluate()`` +7. Run a compact few-shot generation example +8. **Ablation study** comparing three independent axes: + - Cross-attention density (``cross_attn_every_n_layers`` in {1, 2, 4}) + - Perceiver resampler size (``num_resampler_tokens`` in {16, 32, 64}) + - Frozen vs. fine-tunable vision encoder (``freeze_vision`` in {True, False}) + +Ablation motivation: + MedFlamingo's core design choices are (1) how densely to interleave + cross-attention layers between vision and language, (2) how many latent + tokens the Perceiver Resampler compresses visual features into, and (3) + whether the frozen CLIP backbone benefits from end-to-end fine-tuning on + the downstream VQA task. The three ablation axes isolate each variable + while holding the others at the paper's default. + +Usage:: + + # Baseline only (fast): + python examples/vqarad_medvqa_medflamingo.py --root /path/to/vqarad + + # With full ablation study (slower; runs 8 training trials): + python examples/vqarad_medvqa_medflamingo.py --root /path/to/vqarad --ablation + +Note: + The default ``MedFlamingo`` constructor downloads large Hugging Face + weights (CLIP ViT-L/14, OPT-6.7B) on first run, which requires + substantial disk space and memory. 
For fast local testing without + downloading weights, replace ``MedFlamingo`` with the + ``TestableMedFlamingo`` stub from ``tests/core/test_medflamingo.py``. +""" + +from __future__ import annotations + +import argparse +from typing import Dict, List + +from pyhealth.datasets import ( + VQARADDataset, + get_dataloader, + split_by_patient, + split_by_sample, +) +from pyhealth.models import MedFlamingo +from pyhealth.trainer import Trainer + + +# --------------------------------------------------------------------------- +# Helper utilities +# --------------------------------------------------------------------------- + + +def choose_splitter(samples): + """Prefer patient-level splitting when the sample dataset preserves it.""" + patient_to_index = getattr(samples, "patient_to_index", {}) + if patient_to_index: + return split_by_patient, "patient" + return split_by_sample, "sample" + + +def build_few_shot_text(sample: dict) -> str: + """Formats one processed sample as a simple in-context example.""" + return f"Q: {sample['question']}\nA: {sample['answer']}" + + +# --------------------------------------------------------------------------- +# Ablation helpers +# --------------------------------------------------------------------------- + + +def _run_one_config( + samples, + train_ds, + val_ds, + test_ds, + *, + cross_attn_every_n_layers: int, + num_resampler_tokens: int, + freeze_vision: bool, + batch_size: int, + epochs: int, +) -> Dict[str, float]: + """Train and evaluate MedFlamingo for one ablation configuration. + + Args: + samples: The full :class:`~pyhealth.datasets.SampleDataset` used to + configure the model (vocabulary size, feature keys, etc.). + train_ds: Training split. + val_ds: Validation split. + test_ds: Test split. + cross_attn_every_n_layers: How often to insert a gated cross-attention + dense block. Smaller values mean denser vision-language interaction. 
+ num_resampler_tokens: Number of fixed-length visual tokens produced by + the Perceiver Resampler. + freeze_vision: Whether to freeze the CLIP vision encoder weights. + batch_size: DataLoader batch size. + epochs: Number of training epochs. + + Returns: + Dict with keys ``val_accuracy``, ``val_loss``, ``test_accuracy``, and + ``test_loss`` for this configuration. + """ + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + + model = MedFlamingo( + dataset=samples, + cross_attn_every_n_layers=cross_attn_every_n_layers, + num_resampler_tokens=num_resampler_tokens, + freeze_vision=freeze_vision, + ) + + trainer = Trainer(model=model, metrics=["accuracy", "f1_macro"]) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + ) + + val_scores = trainer.evaluate(val_loader) + test_scores = trainer.evaluate(test_loader) + + return { + "val_accuracy": val_scores.get("accuracy", float("nan")), + "val_loss": val_scores.get("loss", float("nan")), + "test_accuracy": test_scores.get("accuracy", float("nan")), + "test_loss": test_scores.get("loss", float("nan")), + } + + +def _print_results_table(rows: List[dict], title: str) -> None: + """Print a formatted results table for the ablation study. + + Args: + rows: List of dicts, each containing ``config`` and four metric keys. + title: Title printed above the table. 
+ """ + print(f"\n{'=' * 72}") + print(f" {title}") + print(f"{'=' * 72}") + header = ( + f"{'Config':<36} {'Val Acc':>8} {'Val Loss':>9}" + f" {'Test Acc':>9} {'Test Loss':>10}" + ) + print(header) + print("-" * 72) + for row in rows: + print( + f"{row['config']:<36}" + f" {row['val_accuracy']:>8.4f}" + f" {row['val_loss']:>9.4f}" + f" {row['test_accuracy']:>9.4f}" + f" {row['test_loss']:>10.4f}" + ) + print("=" * 72) + + +# --------------------------------------------------------------------------- +# Argument parsing +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments. + + Returns: + Parsed argument namespace. + """ + parser = argparse.ArgumentParser( + description="Train MedFlamingo on VQA-RAD with optional ablation study" + ) + parser.add_argument("--root", required=True, help="Path to the VQA-RAD root") + parser.add_argument( + "--cache-dir", + default=None, + help="Optional cache directory for processed dataset artifacts", + ) + parser.add_argument("--dataset-num-workers", type=int, default=1) + parser.add_argument("--task-num-workers", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--ablation", + action="store_true", + help=( + "Run full ablation study across cross_attn_every_n_layers, " + "num_resampler_tokens, and freeze_vision (runs 8 training trials)."
+ ), + ) + return parser.parse_args() + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + args = parse_args() + + # ------------------------------------------------------------------ + # Step 1 – Load dataset + # ------------------------------------------------------------------ + dataset = VQARADDataset( + root=args.root, + cache_dir=args.cache_dir, + num_workers=args.dataset_num_workers, + ) + dataset.stats() + + # ------------------------------------------------------------------ + # Step 2 – Apply task + # ------------------------------------------------------------------ + task_samples = dataset.set_task(num_workers=args.task_num_workers) + + # ------------------------------------------------------------------ + # Step 3 – Split + # ------------------------------------------------------------------ + splitter, split_name = choose_splitter(task_samples) + print(f"Using {split_name}-level split") + train_dataset, val_dataset, test_dataset = splitter( + task_samples, + [0.7, 0.1, 0.2], + seed=42, + ) + + # ------------------------------------------------------------------ + # Steps 4-6 – Baseline training run (default hyperparameters) + # cross_attn_every_n_layers=4, num_resampler_tokens=64, freeze_vision=True + # ------------------------------------------------------------------ + print("\n=== Baseline (xattn_every=4, tokens=64, frozen_vision=True) ===") + train_loader = get_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) + val_loader = get_dataloader(val_dataset, batch_size=args.batch_size, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=args.batch_size, shuffle=False) + + model = MedFlamingo(dataset=task_samples) + trainer = Trainer(model=model, metrics=["accuracy", "f1_macro"]) + + trainer.train( + train_dataloader=train_loader, + 
val_dataloader=val_loader, + epochs=args.epochs, + ) + + test_metrics = trainer.evaluate(test_loader) + print("Baseline test metrics:", test_metrics) + + # ------------------------------------------------------------------ + # Step 7 – Few-shot generation example + # ------------------------------------------------------------------ + query_sample = test_dataset[0] + context_sample = train_dataset[0] + generation = model.generate( + images=[query_sample["image"]], + prompt=query_sample["question"], + few_shot_examples=[ + { + "image": context_sample["image"], + "text": build_few_shot_text(context_sample), + } + ], + max_new_tokens=args.max_new_tokens, + ) + print("Few-shot generation:", generation) + + # ------------------------------------------------------------------ + # Step 8 – Ablation study + # + # Three independent axes are studied: + # + # A) Cross-attention density (cross_attn_every_n_layers ∈ {1, 2, 4}) + # More frequent cross-attention inserts more vision-language bridges + # into the frozen LLM stack. The paper uses every 4th layer; denser + # insertion trades compute for richer multimodal grounding. + # + # B) Perceiver Resampler capacity (num_resampler_tokens ∈ {16, 32, 64}) + # The resampler maps raw CLIP patch tokens to a fixed-length sequence. + # Fewer tokens are cheaper but may lose spatial detail; more tokens + # preserve finer-grained visual information. + # + # C) Vision encoder fine-tuning (freeze_vision ∈ {True, False}) + # The original Flamingo/MedFlamingo paper freezes CLIP to preserve its + # pretrained representations. Unfreezing allows CLIP to adapt to + # medical imagery but risks overfitting on small datasets. + # + # All ablations use a single training epoch for speed; increase --epochs + # for more reliable comparisons. 
+ # ------------------------------------------------------------------ + if args.ablation: + print("\n\n" + "#" * 72) + print("# ABLATION STUDY") + print("#" * 72) + + # ---- Ablation A: cross_attn_every_n_layers ---- + xattn_results = [] + for n in [1, 2, 4]: + print(f"\n--- Ablation A: cross_attn_every_n_layers={n} ---") + scores = _run_one_config( + task_samples, + train_dataset, + val_dataset, + test_dataset, + cross_attn_every_n_layers=n, + num_resampler_tokens=64, # default + freeze_vision=True, # default + batch_size=args.batch_size, + epochs=args.epochs, + ) + xattn_results.append({"config": f"xattn_every={n}", **scores}) + _print_results_table( + xattn_results, + "Ablation A: cross_attn_every_n_layers" + " (tokens=64, frozen_vision=True)", + ) + + # ---- Ablation B: num_resampler_tokens ---- + token_results = [] + for t in [16, 32, 64]: + print(f"\n--- Ablation B: num_resampler_tokens={t} ---") + scores = _run_one_config( + task_samples, + train_dataset, + val_dataset, + test_dataset, + cross_attn_every_n_layers=4, # default + num_resampler_tokens=t, + freeze_vision=True, # default + batch_size=args.batch_size, + epochs=args.epochs, + ) + token_results.append({"config": f"resampler_tokens={t}", **scores}) + _print_results_table( + token_results, + "Ablation B: num_resampler_tokens" + " (xattn_every=4, frozen_vision=True)", + ) + + # ---- Ablation C: freeze_vision ---- + freeze_results = [] + for fv in [True, False]: + label = "frozen" if fv else "fine-tuned" + print(f"\n--- Ablation C: freeze_vision={fv} ({label}) ---") + scores = _run_one_config( + task_samples, + train_dataset, + val_dataset, + test_dataset, + cross_attn_every_n_layers=4, # default + num_resampler_tokens=64, # default + freeze_vision=fv, + batch_size=args.batch_size, + epochs=args.epochs, + ) + freeze_results.append({"config": f"vision_{label}", **scores}) + _print_results_table( + freeze_results, + "Ablation C: freeze_vision" + " (xattn_every=4, resampler_tokens=64)", + ) + + 
print("\nAblation study complete.") + + task_samples.close() diff --git a/pixi.lock b/pixi.lock index 0f11d28d7..d761e3e60 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2224,6 +2224,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2240,6 +2241,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/75/b4/b96bb66f6f8cc4669de44a158099b249c8159231d254ab6b092909388be5/fonttools-4.59.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl @@ -2269,6 +2271,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl @@ -2308,6 +2311,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/34/43/3f250ec28edff1c06ffaa25faddbe13ae85c11a9724894cbdcf89427de78/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/99/72/c86a4cd867816350fe8dee13f30222340b9cd6b96173955819a5561810c5/scikit_learn-1.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl @@ -2360,6 +2364,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/readline-8.2-h8382b9d_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/tk-8.6.13-noxft_h5688188_102.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2376,6 +2381,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/57/7969af50b26408be12baa317c6147588db5b38af2759e6df94554dbc5fdb/fonttools-4.59.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl @@ -2405,6 +2411,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/85/c5/e19c8f99d83fd377ec8c7e0cf627a8049746da54afc24ef0a0cb73d5dfb5/numpy-2.2.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: 
https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl @@ -2430,6 +2437,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/ff/5f/907a48c5f9b83302b4530605df1325963977fdf06753d3d8610d16c40197/rdkit-2025.3.3-cp313-cp313-manylinux_2_28_aarch64.whl - pypi: https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl - pypi: https://files.pythonhosted.org/packages/e8/66/277967b29bd297538dc7a6ecfb1a7dce751beabd0d7f7a2233be7a4f7832/scikit_learn-1.7.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl @@ -2472,6 +2480,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.2-h1d1bf99_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2488,6 +2497,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f3/bb/390990e7c457d377b00890d9f96a3ca13ae2517efafb6609c1756e213ba4/fonttools-4.59.0-cp313-cp313-macosx_10_13_universal2.whl @@ -2517,6 +2527,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/dc/9e/14520dc3dadf3c803473bd07e9b2bd1b69bc583cb2497b47000fed2fa92f/numpy-2.2.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl @@ -2542,6 +2553,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/3b/0b/6ab0cc692b2890f4f7c74f6ffd4bba748dcb9312d5a7bd2328cb82204da1/rdkit-2025.3.3-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/71/f3/f1df377d1bdfc3e3e2adc9c119c238b182293e6740df4cbeac6de2cc3e23/scikit_learn-1.7.1-cp313-cp313-macosx_12_0_arm64.whl @@ -2585,6 +2597,7 @@ environments: - conda: 
https://conda.anaconda.org/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/vc-14.3-h41ae7f8_26.conda - conda: https://conda.anaconda.org/conda-forge/win-64/vc14_runtime-14.44.35208-h818238b_26.conda + - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e7/f9/25753b9de3029d3eb2487755520b98eb72b0cb562d8974329c6e19831063/axial_positional_embedding-0.3.12-py3-none-any.whl @@ -2602,6 +2615,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/1d/54/a46920229d12c3a6e9f0081d1bdaeffad23c1826353ace95714faee926e5/dask-2025.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/ec/da78855318971c2be94d0283a41de6941a6b9f16146fb00babc74903ae01/distributed-2025.11.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a0/ee/f626cd372932d828508137a79b85167fdcf3adab2e3bed433f295c596c6a/fonttools-4.59.0-cp313-cp313-win_amd64.whl @@ -2630,6 +2644,7 @@ environments: - 
pypi: https://files.pythonhosted.org/packages/74/07/1ed8277f8653c40ebc65985180b007879f6a836c525b3885dcc6448ae6cb/msgpack-1.1.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/3b/d58c12eafcb298d4e6d0d40216866ab15f59e55d148a5658bb3132311fcf/numpy-2.2.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7e/95/e0770cf1ad9667492f56b732f44398ef2756d61df914e10d121a3cad013a/ogb-1.3.6-py3-none-any.whl @@ -2655,6 +2670,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/98/da/164e31b607c0cf22f1179cd15fa058780f940b21ec42ba3c9026c21897e3/rdkit-2025.3.3-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl - pypi: 
https://files.pythonhosted.org/packages/e2/47/9291cfa1db1dae9880420d1e07dbc7e8dd4a7cdbc42eaba22512e6bde958/scikit_learn-1.7.1-cp313-cp313-win_amd64.whl @@ -3213,6 +3229,11 @@ packages: purls: [] size: 8191 timestamp: 1744137672556 +- pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + name: absl-py + version: 2.4.0 + sha256: 88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/30/dd/0107f0aa179869ee9f47ef5a2686abd5e022fdc82af901d535e52fe91ce1/accelerate-1.10.0-py3-none-any.whl name: accelerate version: 1.10.0 @@ -3958,6 +3979,11 @@ packages: - pkg:pypi/editables?source=hash-mapping size: 10828 timestamp: 1733208220327 +- pypi: https://files.pythonhosted.org/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz + name: editdistance + version: 0.8.1 + sha256: d1cdf80a5d5014b0c9126a69a42ce55a457b457f6986ff69ca98e4fe4d2d8fed + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl name: einops version: 0.8.2 @@ -5913,6 +5939,32 @@ packages: - pkg:pypi/nh3?source=hash-mapping size: 584955 timestamp: 1756737407424 +- pypi: https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl + name: nltk + version: 3.9.4 + sha256: f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f + requires_dist: + - click + - joblib + - regex>=2021.8.3 + - tqdm + - numpy ; extra == 'machine-learning' + - python-crfsuite ; extra == 'machine-learning' + - scikit-learn ; extra == 'machine-learning' + - scipy ; extra == 'machine-learning' + - matplotlib ; extra == 'plot' + - pyparsing ; extra == 'tgrep' + - twython ; extra == 'twitter' + - requests ; extra == 'corenlp' + - 
scipy ; extra == 'all' + - python-crfsuite ; extra == 'all' + - pyparsing ; extra == 'all' + - requests ; extra == 'all' + - numpy ; extra == 'all' + - scikit-learn ; extra == 'all' + - twython ; extra == 'all' + - matplotlib ; extra == 'all' + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/19/49/4df9123aafa7b539317bf6d342cb6d227e49f7a35b99c287a6109b13dd93/numpy-2.2.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: numpy version: 2.2.6 @@ -7030,7 +7082,7 @@ packages: - pypi: ./ name: pyhealth version: 2.0.0 - sha256: f07719f9dceb759c35507216c8033d2f915d241418d4fad2ab51b37c0e73260f + sha256: 13848208817fed7588e7fd4d5d8b66a5f89c3aeded10a9381dff177d4c790edf requires_dist: - torch~=2.7.1 - torchvision @@ -7055,6 +7107,10 @@ packages: - more-itertools~=10.8.0 - einops>=0.8.0 - linear-attention-transformer>=0.19.1 + - torch-geometric>=2.6.0 ; extra == 'graph' + - editdistance~=0.8.1 ; extra == 'nlp' + - rouge-score~=0.1.2 ; extra == 'nlp' + - nltk~=3.9.1 ; extra == 'nlp' requires_python: '>=3.12,<3.14' - pypi: https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl name: pyparsing @@ -7416,6 +7472,16 @@ packages: - pkg:pypi/rich?source=compressed-mapping size: 201098 timestamp: 1753436991345 +- pypi: https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz + name: rouge-score + version: 0.1.2 + sha256: c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04 + requires_dist: + - absl-py + - nltk + - numpy + - six>=1.14.0 + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl name: s3transfer version: 0.16.0 diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 7400d20cb..f80193b00 100644 --- 
a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -67,6 +67,7 @@ def __init__(self, *args, **kwargs): from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset from .tcga_prad import TCGAPRADDataset +from .vqarad import VQARADDataset from .splitter import ( sample_balanced, split_by_patient, @@ -82,7 +83,6 @@ def __init__(self, *args, **kwargs): ) from .tuab import TUABDataset from .tuev import TUEVDataset -from .vqarad import VQARADDataset from .utils import ( collate_fn_dict, collate_fn_dict_with_padding, diff --git a/pyhealth/datasets/vqarad.py b/pyhealth/datasets/vqarad.py index f2de429b1..44af00c31 100644 --- a/pyhealth/datasets/vqarad.py +++ b/pyhealth/datasets/vqarad.py @@ -17,6 +17,11 @@ root/ VQA_RAD Dataset Public.json + The official OSF archive may keep images in ``VQA_RAD Image Folder/`` + rather than ``images/``. This loader accepts either layout and rewrites + the raw export into ``vqarad-metadata-pyhealth.csv`` for the standard + PyHealth pipeline. + Citation: Lau, J. J., Gayen, S., Ben Abacha, A., & Demner-Fushman, D. (2018). A dataset of clinically generated visual questions and answers about @@ -26,15 +31,15 @@ import json import logging import os +from functools import wraps from pathlib import Path -from typing import Dict, Optional +from typing import Optional import pandas as pd from pyhealth.datasets.sample_dataset import SampleDataset from pyhealth.processors.base_processor import FeatureProcessor from pyhealth.processors.image_processor import ImageProcessor -from pyhealth.tasks.base_task import BaseTask from ..tasks import MedicalVQATask from .base_dataset import BaseDataset @@ -51,8 +56,9 @@ class VQARADDataset(BaseDataset): Args: root: Root directory containing the VQA-RAD data files. - Expected to contain ``VQA_RAD Dataset Public.json`` and an - ``images/`` subdirectory with the radiology images. 
+ Expected to contain ``VQA_RAD Dataset Public.json`` and either + an ``images/`` subdirectory or the original OSF + ``VQA_RAD Image Folder/`` directory with the radiology images. dataset_name: Optional name. Defaults to ``"vqarad"``. config_path: Optional path to a YAML config. If ``None``, uses the bundled ``configs/vqarad.yaml``. @@ -99,9 +105,11 @@ def __init__( def prepare_metadata(self, root: str) -> None: """Convert the raw VQA-RAD JSON into a flat CSV. - The JSON file contains a list of QA entries, each with fields like - ``"IMAGES_PATH"``, ``"QUESTION"``, ``"ANSWER"``, etc. This method - normalises them into a CSV with columns matching the YAML config. + The raw VQA-RAD export may come from different mirrors. This method + accepts both the original OSF field names (for example + ``image_name``, ``question``, ``answer``) and alternate uppercase + field names (for example ``IMAGE_PATH``, ``QUESTION``, ``ANSWER``), + then normalizes them into a CSV with columns matching the YAML config. Args: root: Root directory containing ``VQA_RAD Dataset Public.json``. 
@@ -116,26 +124,67 @@ def prepare_metadata(self, root: str) -> None: with open(json_path, "r") as f: data = json.load(f) + image_root = self._resolve_image_root(root) rows = [] for entry in data: - image_name = entry.get("IMAGE_PATH", entry.get("IMAGES_PATH", "")) - image_path = os.path.join(root, "images", image_name) + image_name = ( + entry.get("IMAGE_PATH") + or entry.get("IMAGES_PATH") + or entry.get("image_name") + or "" + ) + image_path = os.path.join(image_root, image_name) if image_name else "" rows.append( { "image_path": image_path, - "question": entry.get("QUESTION", ""), - "answer": str(entry.get("ANSWER", "")), - "answer_type": entry.get("ANSWER_TYPE", ""), - "question_type": entry.get("QUESTION_TYPE", ""), - "image_organ": entry.get("IMAGE_ORGAN", ""), + "question": entry.get("QUESTION", entry.get("question", "")), + "answer": str(entry.get("ANSWER", entry.get("answer", ""))), + "answer_type": entry.get( + "ANSWER_TYPE", entry.get("answer_type", "") + ), + "question_type": entry.get( + "QUESTION_TYPE", entry.get("question_type", "") + ), + "image_organ": entry.get( + "IMAGE_ORGAN", entry.get("image_organ", "") + ), } ) df = pd.DataFrame(rows) + + # Filter out rows whose image file is missing so that the processor + # pipeline does not fail on incomplete dataset downloads. + before = len(df) + df = df[df["image_path"].apply(lambda p: bool(p) and os.path.isfile(p))] + skipped = before - len(df) + if skipped: + logger.warning( + f"Skipped {skipped} entries with missing image files " + f"(out of {before} total)." 
+ ) + out_path = os.path.join(root, "vqarad-metadata-pyhealth.csv") df.to_csv(out_path, index=False) logger.info(f"Saved VQA-RAD metadata ({len(df)} rows) to {out_path}") + @staticmethod + def _resolve_image_root(root: str) -> str: + """Finds the VQA-RAD image directory for the supported raw layouts.""" + candidate_dirs = [ + os.path.join(root, "images"), + os.path.join(root, "VQA_RAD Image Folder"), + ] + + for candidate in candidate_dirs: + if os.path.isdir(candidate): + return candidate + + raise FileNotFoundError( + "Expected VQA-RAD images in either " + f"{candidate_dirs[0]} or {candidate_dirs[1]}." + ) + @property def default_task(self) -> MedicalVQATask: """Returns the default task for this dataset. @@ -145,34 +194,24 @@ def default_task(self) -> MedicalVQATask: """ return MedicalVQATask() + @wraps(BaseDataset.set_task) def set_task( self, - task: Optional[BaseTask] = None, + *args, image_processor: Optional[FeatureProcessor] = None, **kwargs, ) -> SampleDataset: - """Set a task and return a :class:`SampleDataset`. + """Set a task and inject the default image processor when needed.""" + input_processors = kwargs.get("input_processors", None) - If no ``image_processor`` is provided, defaults to - :class:`~pyhealth.processors.ImageProcessor` with ``mode="RGB"`` - and ``image_size=224`` (matching CLIP ViT input). - - Args: - task: A task instance. Defaults to :meth:`default_task`. - image_processor: Optional custom image processor. - **kwargs: Passed to :meth:`BaseDataset.set_task`. - - Returns: - A :class:`SampleDataset` ready for model training. 
- """ - if task is None: - task = self.default_task + if input_processors is None: + input_processors = {} if image_processor is None: image_processor = ImageProcessor(mode="RGB", image_size=224) - return super().set_task( - task, - image_processor=image_processor, - **kwargs, - ) + if "image" not in input_processors: + input_processors["image"] = image_processor + + kwargs["input_processors"] = input_processors + return super().set_task(*args, **kwargs) diff --git a/pyhealth/models/medflamingo.py b/pyhealth/models/medflamingo.py index f53106762..540cceffd 100644 --- a/pyhealth/models/medflamingo.py +++ b/pyhealth/models/medflamingo.py @@ -27,9 +27,10 @@ - MedFlamingo checkpoint: consult the original repository for terms Note: - This is a stub implementation. Class structure, signatures, and - docstrings are in place, but ``forward()`` and ``generate()`` raise - ``NotImplementedError``. Full implementation is forthcoming. + This implementation exposes both ``forward()`` for PyHealth training + loops and ``generate()`` for direct multimodal prompting. The default + constructor still relies on heavyweight pretrained backbones, so the + first run may download substantial Hugging Face assets. """ from typing import Any, Dict, List, Optional, Tuple @@ -330,10 +331,10 @@ class MedFlamingo(BaseModel): - MedFlamingo checkpoint: see https://github.com/snap-stanford/med-flamingo Note: - This is a stub implementation. ``forward()`` and ``generate()`` - raise ``NotImplementedError``. Heavy dependencies (open_flamingo, - CLIP, LLM weights) will use lazy imports to avoid multi-GB - downloads at import time. + ``forward()`` implements the PyHealth classification-style contract + for dataset-backed usage, while ``generate()`` provides the native + multimodal prompting interface. The default constructor lazily loads + large pretrained dependencies the first time the model is created. 
Args: dataset: A :class:`~pyhealth.datasets.SampleDataset`, or ``None`` @@ -389,6 +390,7 @@ def __init__( # If a dataset is provided with a single label, prepare for # classification (VQA-as-multiclass). + self._fc = None # default; overridden below when dataset is available if dataset is not None and len(self.label_keys) == 1: self.label_key = self.label_keys[0] self._init_classification_head() @@ -693,42 +695,44 @@ def generate( text_embeds = self._lang_model.model.embed_tokens(encoded_context["input_ids"]) # (1, seq_len, lang_dim) - # Step 4: Apply cross-attention for conditioning + # Step 4: Apply cross-attention to produce visually-conditioned embeddings lang_hidden = text_embeds - - # Use all accumulated vision features for conditioning - # For simplicity, concatenate all vision features - all_vision_features = torch.cat(vision_features_list, dim=1) # (batch_size, total_patches, vision_dim) - + + # Concatenate all vision features (few-shot images + query image) + all_vision_features = torch.cat( + vision_features_list, dim=1 + ) # (1, total_patches, vision_dim) + for xattn_layer in self._xattn_layers: - lang_hidden = xattn_layer(lang_hidden, all_vision_features[:1]) # Use first batch's features for single sample - - # Step 5: Prepare input for generation - # Reuse the encoded input IDs but with updated hidden states - input_ids = encoded_context["input_ids"] + lang_hidden = xattn_layer( + lang_hidden, all_vision_features[:1] + ) # use first (and only) batch element + + # Step 5: Generate from the conditioned embeddings. + # Pass ``inputs_embeds`` so the LLM starts from the xattn-conditioned + # representations rather than the raw token embeddings. The + # attention_mask from the tokenizer still applies; a new all-ones mask + # matching the embedding sequence length is used if none is available. 
attention_mask = encoded_context.get("attention_mask") - - # Step 6: Generate using the language model - # We'll craft the generation call to use the conditioned embeddings + with torch.no_grad(): - # Generate from the LLM conditioned on visual features output = self._lang_model.generate( - input_ids=input_ids, + inputs_embeds=lang_hidden, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=(temperature > 1.0), - **generation_kwargs + **generation_kwargs, ) - - # Step 7: Decode generated tokens + + # Step 6: Decode generated tokens generated_text = self._tokenizer.decode( output[0], - skip_special_tokens=True + skip_special_tokens=True, ) - + # Remove prompt from output if present if prompt in generated_text: generated_text = generated_text.split(prompt)[-1].strip() - + return generated_text diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 4ae24ce41..5ded02e7c 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -30,8 +30,8 @@ ) from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4 from .medical_coding import MIMIC3ICD9Coding -from .medical_transcriptions_classification import MedicalTranscriptionsClassification from .medical_vqa_task import MedicalVQATask +from .medical_transcriptions_classification import MedicalTranscriptionsClassification from .mortality_prediction import ( MortalityPredictionEICU, MortalityPredictionEICU2, diff --git a/pyhealth/tasks/medical_vqa_task.py b/pyhealth/tasks/medical_vqa_task.py index 86d616e0b..a4df18209 100644 --- a/pyhealth/tasks/medical_vqa_task.py +++ b/pyhealth/tasks/medical_vqa_task.py @@ -1,69 +1,101 @@ -"""Medical Visual Question Answering (VQA) task. +"""Medical Visual Question Answering task for the VQA-RAD dataset. -This module defines the task for medical VQA, where the model receives a -medical image and a natural-language question and must predict the correct -answer. 
The primary benchmark is VQA-RAD (Lau et al., 2018). +This module defines :class:`MedicalVQATask`, which converts raw VQA-RAD +patient events (each consisting of a radiology image, a clinical question, +and a free-text answer) into image-question-answer samples suitable for +multiclass classification. -The task frames VQA as **multiclass classification** over a closed answer -vocabulary extracted from the training set. This is the standard evaluation -protocol used by MedFlamingo (Moor et al., 2023) and other medical VQA -models on VQA-RAD. +The task frames VQA as **closed-set multiclass classification** over the +vocabulary of all answers seen during training. At inference time the model +selects the most probable answer from this fixed vocabulary. Open-ended +generation is supported separately via :meth:`~pyhealth.models.MedFlamingo.generate`. + +Paper: + Lau et al. "A dataset of clinically generated visual questions and + answers about radiology images." Scientific Data 5, 180251 (2018). + https://doi.org/10.1038/sdata.2018.251 """ from typing import Any, Dict, List +from ..data import Patient from .base_task import BaseTask class MedicalVQATask(BaseTask): - """Task for medical Visual Question Answering (VQA). + """Task for medical visual question answering on the VQA-RAD dataset. - Expects a dataset with medical images, questions, and answers. Each - sample maps an (image, question) pair to a single answer string, - treated as a multiclass classification label. + Each sample pairs a radiology image with a clinical question and maps + the corresponding free-text answer to a class index. The full answer + vocabulary is inferred from the training split by the PyHealth processor + pipeline. - Attributes: - task_name: ``"MedicalVQA"``. - input_schema: ``{"image": "image", "question": "text"}``. - output_schema: ``{"answer": "multiclass"}``. 
+ Input schema: + - ``image`` (``"image"``): A radiology image path, processed by + :class:`~pyhealth.processors.ImageProcessor` into a + ``(3, 224, 224)`` float tensor. + - ``question`` (``"text"``): A free-text clinical question string + (returned as-is by :class:`~pyhealth.processors.TextProcessor`). + + Output schema: + - ``answer`` (``"multiclass"``): The free-text answer string, encoded + as an integer class index by + :class:`~pyhealth.processors.MulticlassProcessor`. - Note: - The ``"text"`` processor for ``"question"`` will tokenize the - question string. If your model needs raw strings instead, you - can override the processor in ``dataset.set_task()``. The assumed - schema here is a reasonable default -- adjust once Teammate A - confirms the final field names and processor types. + Attributes: + task_name: Unique identifier used for cache-key generation. + input_schema: Maps feature names to their processor type strings. + output_schema: Maps label names to their processor type strings. Examples: - >>> from pyhealth.datasets import VQARADDataset >>> from pyhealth.tasks import MedicalVQATask - >>> dataset = VQARADDataset(root="/path/to/vqarad") >>> task = MedicalVQATask() - >>> samples = dataset.set_task(task) + >>> task.task_name + 'MedicalVQA' + >>> task.input_schema + {'image': 'image', 'question': 'text'} + >>> task.output_schema + {'answer': 'multiclass'} """ task_name: str = "MedicalVQA" input_schema: Dict[str, str] = {"image": "image", "question": "text"} output_schema: Dict[str, str] = {"answer": "multiclass"} - def __call__(self, patient: Any) -> List[Dict[str, Any]]: - """Process a patient's VQA data into samples. + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + """Convert a VQA-RAD patient's events into image-question-answer samples. - Each event in the ``"vqarad"`` table becomes one (image, question, - answer) sample. + Iterates over all events of type ``"vqarad"`` attached to ``patient`` + and emits one sample dict per event. 
Events without a valid + ``image_path`` are included; the downstream + :class:`~pyhealth.processors.ImageProcessor` will raise an error if + the path does not point to a readable image file. Args: - patient: A patient object from :class:`~pyhealth.datasets.VQARADDataset`. + patient: A :class:`~pyhealth.data.Patient` object whose events + were populated by :class:`~pyhealth.datasets.VQARADDataset`. Returns: - A list of sample dicts, each with keys ``"image"``, - ``"question"``, and ``"answer"``. + A list of sample dicts, each with the keys: + + - ``"patient_id"`` (:class:`str`): The patient identifier. + - ``"image"`` (:class:`str`): Absolute path to the radiology image. + - ``"question"`` (:class:`str`): The clinical question text. + - ``"answer"`` (:class:`str`): The free-text answer string (will be + encoded as an integer by the multiclass processor). + + Example: + >>> # Typically called internally by BaseDataset.set_task() + >>> samples = dataset.set_task(MedicalVQATask()) + >>> samples[0].keys() + dict_keys(['patient_id', 'image', 'question', 'answer']) """ + samples = [] events = patient.get_events(event_type="vqarad") - samples: List[Dict[str, Any]] = [] for event in events: samples.append( { + "patient_id": patient.patient_id, "image": event.image_path, "question": event.question, "answer": event.answer, diff --git a/tests/core/test_medflamingo.py b/tests/core/test_medflamingo.py index d527f2c37..7c190edc4 100644 --- a/tests/core/test_medflamingo.py +++ b/tests/core/test_medflamingo.py @@ -1,116 +1,550 @@ -"""Test cases for the MedFlamingo model stub.""" +"""Tests for MedFlamingo model, VQARADDataset, and MedicalVQATask. +All tests use synthetic / pseudo data generated in memory or in temporary +directories. No real datasets, internet access, or heavyweight model weights +are required. 
The ``TestableMedFlamingo`` subclass replaces the production +CLIP vision encoder and OPT language model with lightweight stubs so the +entire test suite completes in under a few seconds on CPU. +""" + +import json +import os +import shutil +import tempfile import unittest +import warnings +from types import SimpleNamespace +from PIL import Image import torch +import torch.nn as nn +from pyhealth.data import Patient, Event +from pyhealth.datasets import ( + VQARADDataset, + create_sample_dataset, + get_dataloader, + split_by_sample, +) from pyhealth.models.base_model import BaseModel -from pyhealth.models.medflamingo import MedFlamingo, MedFlamingoLayer - - -class TestMedFlamingoLayer(unittest.TestCase): - """Test cases for MedFlamingoLayer.""" - - def test_layer_initialization_defaults(self): - """Test that MedFlamingoLayer initializes with default params.""" - layer = MedFlamingoLayer() - self.assertEqual(layer.vision_dim, 768) - self.assertEqual(layer.lang_dim, 1024) - self.assertEqual(layer.num_resampler_tokens, 64) - self.assertEqual(layer.num_resampler_layers, 6) - self.assertEqual(layer.num_heads, 8) - self.assertEqual(layer.dropout, 0.0) - - def test_layer_custom_params(self): - """Test MedFlamingoLayer with custom dimensions.""" - layer = MedFlamingoLayer( - vision_dim=512, - lang_dim=2048, - num_resampler_tokens=32, - num_resampler_layers=4, - num_heads=16, - dropout=0.1, +from pyhealth.models.medflamingo import MedFlamingo +from pyhealth.tasks import MedicalVQATask +from pyhealth.trainer import Trainer + + +REAL_VQARAD_ROOT = os.getenv("PYHEALTH_VQARAD_ROOT") + +warnings.filterwarnings( + "ignore", + message=r"A newer version of litdata is available .*", + category=UserWarning, +) + + +# --------------------------------------------------------------------------- +# Lightweight model stubs (no CLIP / OPT downloads) +# --------------------------------------------------------------------------- + + +class FakeBatch(dict): + def to(self, device): + return 
FakeBatch({key: value.to(device) for key, value in self.items()}) + + +class FakeTokenizer: + def __init__(self): + self.pad_token = None + self.eos_token = "" + self.last_text = "" + + def __call__( + self, + texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ): + if isinstance(texts, str): + texts = [texts] + self.last_text = texts[0] + seq_len = min(max(len(text.split()) for text in texts) + 1, max_length) + input_ids = [] + attention_mask = [] + for row, text in enumerate(texts): + tokens = [(row + idx) % 17 + 1 for idx, _ in enumerate(text.split()[:seq_len])] + tokens = tokens + [0] * (seq_len - len(tokens)) + mask = [1 if token != 0 else 0 for token in tokens] + if not any(mask): + tokens[0] = 1 + mask[0] = 1 + input_ids.append(tokens) + attention_mask.append(mask) + return FakeBatch( + { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), + } ) - self.assertEqual(layer.vision_dim, 512) - self.assertEqual(layer.lang_dim, 2048) - self.assertEqual(layer.num_resampler_tokens, 32) - self.assertEqual(layer.num_resampler_layers, 4) - self.assertEqual(layer.num_heads, 16) - self.assertEqual(layer.dropout, 0.1) - - def test_layer_forward_raises(self): - """Test that forward raises NotImplementedError (stub).""" - layer = MedFlamingoLayer() - lang_hidden = torch.randn(2, 10, 1024) - vision_features = torch.randn(2, 196, 768) - with self.assertRaises(NotImplementedError): - layer(lang_hidden, vision_features) - - def test_layer_is_nn_module(self): - """Test that MedFlamingoLayer is an nn.Module.""" - layer = MedFlamingoLayer() - self.assertIsInstance(layer, torch.nn.Module) + + def decode(self, tokens, skip_special_tokens=True): + return f"{self.last_text} synthetic answer" + + +class FakeLanguageInnerModel(nn.Module): + def __init__(self, vocab_size=32, hidden_size=8): + super().__init__() + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + + +class 
FakeLanguageModel(nn.Module): + def __init__(self, hidden_size=8, num_hidden_layers=4): + super().__init__() + self.config = SimpleNamespace( + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + ) + self.model = FakeLanguageInnerModel(hidden_size=hidden_size) + + def generate( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + max_new_tokens=16, + **kwargs, + ): + # Accept either input_ids or inputs_embeds; generate() passes inputs_embeds + # so that the xattn-conditioned representations are forwarded to the LLM. + if inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + else: + batch_size = input_ids.shape[0] + device = input_ids.device + return torch.full( + (batch_size, min(max_new_tokens, 4)), + fill_value=7, + dtype=torch.long, + device=device, + ) + + +class FakeVisionEncoder(nn.Module): + def __init__(self, hidden_size=8, num_tokens=5): + super().__init__() + self.config = SimpleNamespace(hidden_size=hidden_size) + self.num_tokens = num_tokens + self.proj = nn.Linear(1, hidden_size) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + pooled = pixel_values.float().reshape(batch_size, -1).mean(dim=1, keepdim=True) + repeated = pooled.unsqueeze(1).repeat(1, self.num_tokens, 1) + return SimpleNamespace(last_hidden_state=self.proj(repeated)) + + +class TestableMedFlamingo(MedFlamingo): + __test__ = False + + def _init_vision_encoder(self) -> None: + self._vision_encoder = FakeVisionEncoder() + if self.freeze_vision: + for param in self._vision_encoder.parameters(): + param.requires_grad = False + + def _init_lang_model(self) -> None: + self._lang_model = FakeLanguageModel() + self._tokenizer = FakeTokenizer() + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token + if self.freeze_lm: + for param in self._lang_model.parameters(): + param.requires_grad = False + + +# 
--------------------------------------------------------------------------- +# Test suite +# --------------------------------------------------------------------------- class TestMedFlamingo(unittest.TestCase): - """Test cases for the MedFlamingo model.""" + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.vqarad_root = tempfile.mkdtemp() + cls.vqarad_cache_dir = tempfile.mkdtemp() + cls.samples = [] + labels = ["yes", "no", "yes", "no"] + questions = [ + "is there a fracture", + "is the study normal", + "is there consolidation", + "is there edema", + ] + + for idx, (answer, question) in enumerate(zip(labels, questions)): + image_path = os.path.join(cls.temp_dir, f"img_{idx}.png") + image = Image.fromarray( + torch.randint(0, 255, (16, 16, 3), dtype=torch.uint8).numpy(), + mode="RGB", + ) + image.save(image_path) + cls.samples.append( + { + "patient_id": f"patient-{idx // 2}", + "visit_id": f"visit-{idx}", + "image": image_path, + "question": question, + "answer": answer, + } + ) + + cls.dataset = create_sample_dataset( + samples=cls.samples, + input_schema={ + "image": ("image", {"image_size": 16, "mode": "RGB"}), + "question": "text", + }, + output_schema={"answer": "multiclass"}, + dataset_name="test_medflamingo", + ) + + cls._create_vqarad_fixture( + cls.vqarad_root, + num_examples=8, + ) + + @classmethod + def _create_vqarad_fixture(cls, root, num_examples): + images_dir = os.path.join(root, "images") + os.makedirs(images_dir, exist_ok=True) + entries = [] + answers = ["yes", "no"] * (num_examples // 2) + questions = [ + "is there a fracture", + "is the study normal", + "is there consolidation", + "is there edema", + "is there a mass", + "is there pleural effusion", + "is there cardiomegaly", + "is there pneumothorax", + ] + + for idx in range(num_examples): + image_name = f"study_{idx}.png" + image_path = os.path.join(images_dir, image_name) + image = Image.fromarray( + torch.randint(0, 255, (16, 16, 3), 
dtype=torch.uint8).numpy(), + mode="RGB", + ) + image.save(image_path) + entries.append( + { + "IMAGE_PATH": image_name, + "QUESTION": questions[idx % len(questions)], + "ANSWER": answers[idx % len(answers)], + "ANSWER_TYPE": "closed", + "QUESTION_TYPE": "presence", + "IMAGE_ORGAN": "chest", + } + ) + + with open(os.path.join(root, "VQA_RAD Dataset Public.json"), "w") as f: + json.dump(entries, f) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.temp_dir) + shutil.rmtree(cls.vqarad_root) + shutil.rmtree(cls.vqarad_cache_dir) + + def _build_vqarad_sample_dataset(self): + dataset = VQARADDataset( + root=self.vqarad_root, + cache_dir=self.vqarad_cache_dir, + num_workers=1, + ) + return dataset.set_task(num_workers=1) + + # ------------------------------------------------------------------ + # MedicalVQATask unit tests + # ------------------------------------------------------------------ + + def test_medical_vqa_task_schema(self): + """Task declares the expected input/output schema.""" + task = MedicalVQATask() + self.assertEqual(task.task_name, "MedicalVQA") + self.assertEqual(task.input_schema, {"image": "image", "question": "text"}) + self.assertEqual(task.output_schema, {"answer": "multiclass"}) + + def test_medical_vqa_task_call_emits_correct_fields(self): + """__call__ returns one sample per vqarad event with all required keys.""" + import polars as pl + from datetime import datetime + + task = MedicalVQATask() + + # Patient expects a Polars DataFrame with columns: + # event_type, timestamp, vqarad/ + rows = [ + { + "event_type": "vqarad", + "timestamp": datetime(2020, 1, i + 1), + "vqarad/image_path": f"/data/images/img_{i}.jpg", + "vqarad/question": f"Is there a fracture? 
({i})", + "vqarad/answer": "yes" if i % 2 == 0 else "no", + } + for i in range(3) + ] + df = pl.DataFrame(rows) + patient = Patient(patient_id="p-001", data_source=df) + + samples = task(patient) + + self.assertEqual(len(samples), 3) + for sample in samples: + self.assertIn("patient_id", sample) + self.assertIn("image", sample) + self.assertIn("question", sample) + self.assertIn("answer", sample) + self.assertEqual(sample["patient_id"], "p-001") + + def test_medical_vqa_task_call_empty_patient(self): + """__call__ returns an empty list when the patient has no vqarad events.""" + import polars as pl + + task = MedicalVQATask() + # DataFrame with required columns but zero rows + df = pl.DataFrame({"event_type": [], "timestamp": []}).with_columns( + pl.col("timestamp").cast(pl.Datetime) + ) + patient = Patient(patient_id="p-empty", data_source=df) + self.assertEqual(task(patient), []) + + # ------------------------------------------------------------------ + # MedFlamingo model unit tests + # ------------------------------------------------------------------ def test_model_initialization_standalone(self): - """Test MedFlamingo initializes without a dataset.""" - model = MedFlamingo(dataset=None) + """Standalone model (no dataset) initialises with expected defaults.""" + model = TestableMedFlamingo(dataset=None) self.assertIsInstance(model, MedFlamingo) + self.assertIsInstance(model, BaseModel) self.assertEqual(model.vision_model_name, "openai/clip-vit-large-patch14") self.assertEqual(model.lang_model_name, "facebook/opt-6.7b") - self.assertIsNone(model.medflamingo_checkpoint) - self.assertEqual(model.cross_attn_every_n_layers, 4) - self.assertEqual(model.num_resampler_tokens, 64) - self.assertTrue(model.freeze_vision) - self.assertTrue(model.freeze_lm) - - def test_model_custom_params(self): - """Test MedFlamingo with custom model names and config.""" - model = MedFlamingo( - dataset=None, - vision_model_name="openai/clip-vit-base-patch32", - 
lang_model_name="facebook/opt-1.3b", - cross_attn_every_n_layers=2, - num_resampler_tokens=32, - freeze_vision=False, + # FakeLanguageModel has 4 hidden layers; cross_attn_every_n_layers=4 + # yields exactly 1 xattn layer (4 // 4 = 1). + self.assertEqual(len(model._xattn_layers), 1) + self.assertEqual(model._tokenizer.pad_token, model._tokenizer.eos_token) + # _fc must be None when no dataset is supplied + self.assertIsNone(model._fc) + + def test_forward_smoke_with_dataset_batch(self): + """forward() returns all required keys with correct batch and class dimensions.""" + model = TestableMedFlamingo(dataset=self.dataset) + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + output = model(**batch) + + self.assertIn("loss", output) + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertIn("logit", output) + # Batch dimension + self.assertEqual(output["logit"].shape[0], 2) + self.assertEqual(output["y_prob"].shape[0], 2) + self.assertEqual(output["y_true"].shape[0], 2) + # Class dimension must match the vocabulary size inferred by the processor + expected_num_classes = self.dataset.output_processors["answer"].size() + self.assertEqual(output["logit"].shape[1], expected_num_classes) + self.assertEqual(output["y_prob"].shape[1], expected_num_classes) + + def test_generate_smoke_single_image(self): + """generate() returns a non-empty string for a single image + prompt.""" + model = TestableMedFlamingo(dataset=None) + response = model.generate( + images=[torch.randn(3, 16, 16)], + prompt="what does the image show", + max_new_tokens=8, ) - self.assertEqual(model.vision_model_name, "openai/clip-vit-base-patch32") - self.assertEqual(model.lang_model_name, "facebook/opt-1.3b") - self.assertEqual(model.cross_attn_every_n_layers, 2) - self.assertEqual(model.num_resampler_tokens, 32) - self.assertFalse(model.freeze_vision) - - def test_forward_raises(self): - """Test that forward 
raises NotImplementedError (stub).""" - model = MedFlamingo(dataset=None) - with self.assertRaises(NotImplementedError): - model.forward() - - def test_generate_raises(self): - """Test that generate raises NotImplementedError (stub).""" - model = MedFlamingo(dataset=None) - dummy_image = torch.randn(3, 224, 224) - with self.assertRaises(NotImplementedError): - model.generate(images=[dummy_image], prompt="What is shown?") - - def test_inherits_base_model(self): - """Test that MedFlamingo inherits from BaseModel.""" - model = MedFlamingo(dataset=None) - self.assertIsInstance(model, BaseModel) - def test_standalone_has_empty_keys(self): - """Test that standalone model has empty feature/label keys.""" - model = MedFlamingo(dataset=None) - self.assertEqual(model.feature_keys, []) - self.assertEqual(model.label_keys, []) + self.assertIsInstance(response, str) + self.assertIn("synthetic answer", response) + + def test_generate_smoke_with_few_shot_examples(self): + """generate() returns a string when few-shot context images are provided.""" + model = TestableMedFlamingo(dataset=None) + response = model.generate( + images=[torch.randn(3, 16, 16)], + prompt="what is the main finding", + few_shot_examples=[ + { + "image": torch.randn(3, 16, 16), + "text": "Q: is there a fracture?\nA: no", + } + ], + max_new_tokens=8, + ) + + self.assertIsInstance(response, str) + self.assertIn("synthetic answer", response) + + def test_generate_uses_inputs_embeds(self): + """generate() passes inputs_embeds (not input_ids) so xattn conditioning applies.""" + seen_kwargs = {} + + original_generate = FakeLanguageModel.generate + + def patched_generate(self, **kwargs): + seen_kwargs.update(kwargs) + return original_generate(self, **kwargs) + + model = TestableMedFlamingo(dataset=None) + model._lang_model.generate = lambda **kw: (seen_kwargs.update(kw) or original_generate(model._lang_model, **kw)) + + model.generate( + images=[torch.randn(3, 16, 16)], + prompt="is there a fracture", + 
max_new_tokens=4, + ) + + self.assertIn("inputs_embeds", seen_kwargs) + self.assertNotIn("input_ids", seen_kwargs) + + def test_gradients_flow_through_xattn_layers(self): + """Only xattn layers and the classification head receive gradients.""" + model = TestableMedFlamingo(dataset=self.dataset) + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + output = model(**batch) + output["loss"].backward() + + trainable_with_grad = { + name + for name, param in model.named_parameters() + if param.requires_grad and param.grad is not None + } + + # xattn layers must receive gradients + self.assertTrue( + any(name.startswith("_xattn_layers") for name in trainable_with_grad) + ) + # Frozen vision encoder must NOT receive gradients + self.assertFalse( + any(name.startswith("_vision_encoder") for name in trainable_with_grad) + ) + # Frozen language model must NOT receive gradients + self.assertFalse( + any(name.startswith("_lang_model") for name in trainable_with_grad) + ) + # Classification head must receive gradients + self.assertTrue(any(name.startswith("_fc") for name in trainable_with_grad)) + # No other parameters should have gradients + self.assertEqual( + { + name + for name in trainable_with_grad + if not (name.startswith("_xattn_layers") or name.startswith("_fc")) + }, + set(), + msg="Unexpected parameters received gradients", + ) + + # ------------------------------------------------------------------ + # VQARADDataset integration tests + # ------------------------------------------------------------------ + + def test_forward_smoke_with_vqarad_dataset_batch(self): + """forward() works end-to-end on a batch from the VQARADDataset pipeline.""" + samples = self._build_vqarad_sample_dataset() + try: + model = TestableMedFlamingo(dataset=samples) + loader = get_dataloader(samples, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + output = model(**batch) + + self.assertIn("loss", output) + 
self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertIn("logit", output) + self.assertEqual(output["logit"].shape[0], 2) + self.assertEqual( + output["logit"].shape[1], + samples.output_processors["answer"].size(), + ) + finally: + samples.close() + + @unittest.skipUnless( + REAL_VQARAD_ROOT, + "set PYHEALTH_VQARAD_ROOT to run the real VQA-RAD batch smoke test", + ) + def test_forward_with_real_vqarad_batch_if_available(self): + real_cache_dir = tempfile.mkdtemp() + try: + dataset = VQARADDataset( + root=REAL_VQARAD_ROOT, + cache_dir=real_cache_dir, + num_workers=1, + dev=True, + ) + samples = dataset.set_task(num_workers=1) + try: + model = TestableMedFlamingo(dataset=samples) + loader = get_dataloader(samples, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + output = model(**batch) + + self.assertIn("loss", output) + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertIn("logit", output) + finally: + samples.close() + finally: + shutil.rmtree(real_cache_dir) + + def test_trainer_with_small_vqarad_sample(self): + """Trainer.train() and Trainer.evaluate() complete without error on tiny data.""" + samples = self._build_vqarad_sample_dataset() + try: + train_dataset, val_dataset, test_dataset = split_by_sample( + samples, + [0.5, 0.25, 0.25], + seed=42, + ) + train_loader = get_dataloader(train_dataset, batch_size=2, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=2, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=2, shuffle=False) + + model = TestableMedFlamingo(dataset=samples) + trainer = Trainer( + model=model, + metrics=["accuracy"], + device="cpu", + enable_logging=False, + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=1, + load_best_model_at_last=False, + ) + scores = trainer.evaluate(test_loader) - def test_device_property(self): - """Test that the device property works (inherited from 
BaseModel).""" - model = MedFlamingo(dataset=None) - self.assertIsInstance(model.device, torch.device) + self.assertIn("loss", scores) + self.assertIn("accuracy", scores) + finally: + samples.close() if __name__ == "__main__":