# Multimodal Cancer Modeling in the Age of Foundation Model Embeddings
### Steven Song\*, Morgan Borjigin-Wang\*, Irene R. Madejski, Robert L. Grossman

\* Equal contribution

Read our paper here: https://arxiv.org/abs/2505.07683

***

The Cancer Genome Atlas (TCGA) has enabled novel discoveries and served as a large-scale reference dataset in cancer through its harmonized genomics, clinical, and imaging data. Numerous prior studies have developed bespoke deep learning models over TCGA for tasks such as cancer survival prediction. A modern paradigm in biomedical deep learning is the development of foundation models (FMs) to derive feature embeddings agnostic to a specific modeling task. Biomedical text especially has seen growing development of FMs. While TCGA contains free-text data as pathology reports, these have been historically underutilized.

* **We investigate the ability to train classical machine learning models over multimodal, zero-shot FM embeddings of cancer data.**
* We demonstrate the ease and additive effect of multimodal fusion, outperforming unimodal models.
* Overall, we propose a simple, modernized approach to multimodal cancer modeling using FM embeddings.

### Overview

<img src="https://raw.githubusercontent.com/StevenSong/multimodal-cancer-modeling/refs/heads/main/overview.png" alt="conceptual overview figure" width="50%"/>

Conceptually, the proposed framework performs late fusion of unimodal models, each trained over its own modality's embeddings. Specifically, we use:
* BulkRNABert ([Gélard et al. 2025](https://proceedings.mlr.press/v259/gelard25a.html)) for RNA-seq data
* UNI2-h ([Chen et al. 2024](https://www.nature.com/articles/s41591-024-02857-3)) for histology data
* BioMistral ([Labrak et al. 2024](https://aclanthology.org/2024.findings-acl.348/)) for pathology report data (summarized by Llama-3.1-8B-Instruct ([Grattafiori et al. 2024](https://arxiv.org/abs/2407.21783)))

We can specifically break down the required steps for our approach as follows:
1. Pull TCGA-LUAD embeddings
1. Pull TCGA-LUAD metadata
1. Merge embeddings and metadata
1. Prepare experiments
1. Train unimodal models
1. Train multimodal model
1. Evaluate models

In [None]:
!pip install h5py pandas numpy tqdm pqdm scikit-learn scikit-survival jupyterlab ipywidgets jsonrpcclient gen3

In [1]:
import json
import requests
from io import StringIO

import pandas as pd
import numpy as np
from tqdm import tqdm
from pqdm.threads import pqdm
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
from jsonrpcclient import request, parse_json, Ok, Error
from gen3.auth import Gen3Auth

#### Pull TCGA Embeddings
This retrieves embeddings using the AI commons embedding service!

In [2]:
auth = Gen3Auth()

expr_embs = dict()
hist_embs = dict()
text_embs = dict()

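def query(guid):
    # ask the commons for a presigned URL, then fetch the record it points to
    response = requests.get(
        f"https://m3aicommons.org/user/data/download/{guid}",
        auth=auth,
    )
    response.raise_for_status()
    record = requests.get(response.json()["url"])
    record.raise_for_status()
    # assumption: the record is JSON holding "file_id" and "embedding", as the loop below expects
    return record.json()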

for fpath, embs, do_upper in [
    ("manifests/expr-luad.json", expr_embs, False),
    ("manifests/hist-luad.json", hist_embs, False),
    ("manifests/summ-luad.json", text_embs, True),
]:
    with open(fpath, "r") as f:
        guids = [x["guid"] for x in json.load(f)]
    temp = pqdm(guids, query, n_jobs=500) # ongoing work to batch download
    for x in tqdm(temp):
        fid = x["file_id"]
        if do_upper:  # some modalities need uppercased keys, an artifact of a mismatch in upstream preprocessing by other research groups
            fid = fid.upper()
        embs[fid] = np.asarray(json.loads(x["embedding"]))

QUEUEING TASKS | :   0%|          | 0/439 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/439 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/439 [00:00<?, ?it/s]

100%|██████████| 439/439 [00:00<00:00, 11547.02it/s]


QUEUEING TASKS | :   0%|          | 0/433 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/433 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/433 [00:00<?, ?it/s]

100%|██████████| 433/433 [00:00<00:00, 3080.44it/s]


QUEUEING TASKS | :   0%|          | 0/488 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/488 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/488 [00:00<?, ?it/s]

100%|██████████| 488/488 [00:00<00:00, 1152.36it/s]
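
The download loop above fans out one request per GUID across a thread pool and then normalizes the keys. As a minimal sketch of the same pattern using only the standard library, the block below swaps the network call for `fake_fetch`, a hypothetical stub; the record layout (`file_id` plus a JSON-encoded `embedding`) is inferred from how the loop consumes its results and is an assumption, not a documented schema.

```python
import json
from concurrent.futures import ThreadPoolExecutor


def fake_fetch(guid: str) -> dict:
    # Hypothetical stand-in for the network call: returns a record shaped
    # like the ones the download loop consumes.
    return {"file_id": f"file-{guid}", "embedding": json.dumps([0.1, 0.2, 0.3])}


def fetch_all(guids, fetch, max_workers=8, do_upper=False):
    embs = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the same order as the input GUIDs.
        for record in pool.map(fetch, guids):
            fid = record["file_id"].upper() if do_upper else record["file_id"]
            embs[fid] = json.loads(record["embedding"])
    return embs


demo_embs = fetch_all(["a1", "b2"], fake_fetch, do_upper=True)
print(sorted(demo_embs))
```

A modest `max_workers` is usually enough for I/O-bound downloads; the right value depends on the service's rate limits.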


#### Pull TCGA Metadata

In our embedding service, the only metadata we store with the embeddings is the filename. This is done to prevent extensive duplication of metadata columns from the GDC. We can use the GDC API itself to get a mapping of the filenames to other metadata, such as case ID and survival outcomes.

Here, we also demo using the GDC Cohort Copilot as an MCP tool to construct GDC cohort filters. Ideally an LLM would call this tool, but we can also invoke it manually! This makes it easier to construct the aforementioned metadata mappings.

In [3]:
endpoint = "https://apps.m3aicommons.org/gdc-cohort-copilot/gradio_api/mcp/"
headers = {"Accept": "application/json, text/event-stream"}

# helper to parse the SSE-formatted response from the MCP server
def parse_sse(response):
    jsonrpc_str = response.text.split("\n", 1)[1].replace("data: ", "", 1)
    match parse_json(jsonrpc_str):
        case Ok(result, id_):
            return result
        case Error(code, message, data, id_):
            raise ValueError(message)

# helper to use the MCP tool to construct a GDC cohort filter
def get_cohort_filter(query: str) -> str:
    response = requests.post(
        endpoint,
        json=request(
            method="tools/call",
            params={
                "name": "generate_filter",
                "arguments": {
                    "query": query,
                },
            },
        ),
        headers=headers,
    )

    result = parse_sse(response)
    return result["content"][0]["text"]

# helper to convert GDC request to file mapping
def get_file_mapping(cohort_filter):
    df = pd.read_csv(
        StringIO(
            requests.get(
                "https://api.gdc.cancer.gov/files",
                params={
                    "filters": cohort_filter,
                    "fields": ",".join(
                        [
                            "file_name",
                            "cases.project.project_id",
                            "cases.submitter_id",
                            "experimental_strategy",
                            "file_size",
                            "md5sum",
                            "state",
                        ]
                    ),
                    "format": "TSV",
                    "size": str(100_000),
                },
            ).text
        ),
        sep="\t",
    )
    mapping = df[["cases.0.submitter_id", "file_name"]].rename(
        columns={
            "cases.0.submitter_id": "case_id",
        }
    )
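    mapping["file_name"] = (
        mapping["file_name"]
        # regex=False: treat the suffixes as literal text so "." is not a wildcard
        .str.replace(".svs", "", regex=False)
        .str.replace(".rna_seq.augmented_star_gene_counts.tsv", "", regex=False)
        .str.replace(".PDF", "", regex=False)
    )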
    return mapping.set_index("case_id")["file_name"]
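
The `parse_sse` helper above assumes the JSON-RPC payload sits on the second line of the response, behind a `data: ` prefix. A slightly more defensive sketch is shown below: it collects every `data:` line of the first event before parsing. The sample body is hypothetical and only mimics the general shape of a server-sent event carrying a JSON-RPC result.

```python
import json


def extract_sse_data(body: str) -> dict:
    """Join the `data:` lines of the first SSE event and parse them as JSON."""
    data_lines = []
    for line in body.splitlines():
        if line.startswith("data:"):
            value = line[len("data:"):]
            # Per the SSE format, a single leading space after the colon is dropped.
            if value.startswith(" "):
                value = value[1:]
            data_lines.append(value)
        elif line == "" and data_lines:
            break  # a blank line terminates the event
    return json.loads("\n".join(data_lines))


# Hypothetical response body, for illustration only.
sample_body = (
    "event: message\n"
    'data: {"jsonrpc": "2.0", "id": 1, '
    '"result": {"content": [{"type": "text", "text": "{}"}]}}\n'
    "\n"
)
message = extract_sse_data(sample_body)
filter_text = message["result"]["content"][0]["text"]
```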

In [4]:
# Get the tool list
response = requests.post(
    endpoint,
    json=request(method="tools/list"),
    headers=headers,
)

print(json.dumps(parse_sse(response), indent=4))

{
    "tools": [
        {
            "name": "generate_filter",
            "description": "Converts a free text description of a cancer cohort into a GDC structured cohort filter. Returns: str: JSON structured GDC cohort filter",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The free text cohort description"
                    }
                }
            }
        }
    ]
}


In [5]:
# Construct some cohort filters, you can print these to see what they look like!
hist_filter = get_cohort_filter("TCGA-LUAD samples with diagnostic slides as svs files")
expr_filter = get_cohort_filter("TCGA-LUAD samples with RNA-sequences gene expression quantification data as tsv files")
text_filter = get_cohort_filter("TCGA-LUAD samples with pathology report pdf files")
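
The copilot returns a GDC filter as a JSON string, so an equivalent filter can also be assembled by hand. The sketch below rebuilds the shape of the histology filter printed in the next cell's output; `in_clause` and `and_filter` are hypothetical helper names, and the field names are copied from that output.

```python
import json


def in_clause(field: str, values: list) -> dict:
    # One "field in [values]" condition in GDC filter syntax.
    return {"op": "in", "content": {"field": field, "value": values}}


def and_filter(*clauses: dict) -> str:
    # Combine conditions with a top-level "and" and serialize to a JSON string.
    return json.dumps({"op": "and", "content": list(clauses)}, indent=4)


manual_hist_filter = and_filter(
    in_clause("cases.project.project_id", ["TCGA-LUAD"]),
    in_clause("files.experimental_strategy", ["Diagnostic Slide"]),
    in_clause("files.data_format", ["svs"]),
)
print(manual_hist_filter)
```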

In [6]:
print(hist_filter)
print(expr_filter)
print(text_filter)

{
    "op": "and",
    "content": [
        {
            "op": "in",
            "content": {
                "field": "cases.project.project_id",
                "value": [
                    "TCGA-LUAD"
                ]
            }
        },
        {
            "op": "in",
            "content": {
                "field": "files.experimental_strategy",
                "value": [
                    "Diagnostic Slide"
                ]
            }
        },
        {
            "op": "in",
            "content": {
                "field": "files.data_format",
                "value": [
                    "svs"
                ]
            }
        }
    ]
}
{
    "op": "and",
    "content": [
        {
            "op": "in",
            "content": {
                "field": "cases.project.project_id",
                "value": [
                    "TCGA-LUAD"
                ]
            }
        },
        {
            "op": "in",
            "content": {
         

In [7]:
# Query the API using the filters to get a mapping of case ID --> filename (from the embedding service)
hist_mapping = get_file_mapping(hist_filter)
expr_mapping = get_file_mapping(expr_filter)
text_mapping = get_file_mapping(text_filter)
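
The mapping values only line up with the embedding keys once the file extensions are removed. `str.replace` removes a match anywhere in the name, so as a hedged alternative the sketch below strips a known extension only when it is a true suffix (using `str.removesuffix`, Python 3.9+). The example filenames are hypothetical.

```python
SUFFIXES = (
    ".rna_seq.augmented_star_gene_counts.tsv",
    ".svs",
    ".PDF",
)


def strip_known_suffix(file_name: str) -> str:
    # Remove the first matching extension, and only at the end of the name.
    for suffix in SUFFIXES:
        if file_name.endswith(suffix):
            return file_name.removesuffix(suffix)
    return file_name


# Hypothetical filenames, for illustration only.
examples = [
    "slide-01.svs",
    "abc123.rna_seq.augmented_star_gene_counts.tsv",
    "REPORT-7.PDF",
    "notes.txt",
]
stripped = [strip_known_suffix(name) for name in examples]
```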

In [8]:
# Get clinical metadata, this is standard code to pull data from the GDC
# (though we use an extra call to the GDC Cohort Copilot for fun)
response = requests.get(
    "https://api.gdc.cancer.gov/cases",
    params={
        "filters": get_cohort_filter("samples from TCGA-LUAD"),
        "fields": ",".join(
            [
                "project.project_id",
                "submitter_id",
                "diagnoses.age_at_diagnosis",
                "diagnoses.diagnosis_is_primary_disease",
                "diagnoses.days_to_last_follow_up",
                "demographic.days_to_death",
                "demographic.vital_status",
                "demographic.ethnicity",
                "demographic.gender",
                "demographic.race",
            ]
        ),
        "format": "JSON",
        "size": str(100_000),
    },
)
temp = json.loads(response.content)

data = []
for sample in temp["data"]["hits"]:
    if "diagnoses" not in sample:
        continue
    for diagnosis in sample["diagnoses"]:
        if "diagnosis_is_primary_disease" in diagnosis and diagnosis["diagnosis_is_primary_disease"]:
            break
    else:
        # no primary disease, skip
        continue
    if "days_to_last_follow_up" not in diagnosis:
        # primary diagnosis does not have f/u
        continue
    data.append({
        "project_id": sample["project"]["project_id"],
        "case_id": sample["submitter_id"],
        "age_at_diagnosis": diagnosis["age_at_diagnosis"],
        "days_to_last_fu": diagnosis["days_to_last_follow_up"],
        "sex": sample["demographic"]["gender"],
        "race": sample["demographic"]["race"],
        "ethnicity": sample["demographic"]["ethnicity"],
        "vital_status": sample["demographic"]["vital_status"],
    })
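
The loop above relies on Python's `for ... else` to find the primary diagnosis: the `else` branch runs only when the loop finishes without hitting `break`. The same selection can be sketched as a small function on a toy record; `pick_primary_diagnosis` and the case contents are hypothetical, and only the field names come from the request above.

```python
from typing import Optional


def pick_primary_diagnosis(case: dict) -> Optional[dict]:
    # Return the first diagnosis flagged as the primary disease, if any.
    for diagnosis in case.get("diagnoses", []):
        if diagnosis.get("diagnosis_is_primary_disease"):
            return diagnosis
    return None


# Hypothetical case record, for illustration only.
toy_case = {
    "submitter_id": "CASE-0001",
    "diagnoses": [
        {"diagnosis_is_primary_disease": False, "days_to_last_follow_up": 10},
        {"diagnosis_is_primary_disease": True, "days_to_last_follow_up": 400},
    ],
}
primary = pick_primary_diagnosis(toy_case)
```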

#### Merge Embeddings and Metadata

In [9]:
case_ids = ( # case IDs that have all 3 embedding modalities
    set(expr_mapping[expr_mapping.isin(expr_embs)].index) &
    set(hist_mapping[hist_mapping.isin(hist_embs)].index) &
    set(text_mapping[text_mapping.isin(text_embs)].index)
)

In [10]:
df = pd.DataFrame(data)
df = df[
    df["case_id"].isin(case_ids)
    & df["vital_status"].isin({"Alive", "Dead"})
    & df["age_at_diagnosis"].notna()
    & df["days_to_last_fu"].notna()
    & (df["days_to_last_fu"] >= 0)
].copy()  # copy so the column assignments below do not write into a filtered view
df["age_at_diagnosis"] /= 365.24 # convert age to years
df["age_binned"] = pd.cut(df["age_at_diagnosis"], bins=np.arange(0, 101, 20))

In [11]:
expr_X = []
hist_X = []
text_X = []
for _, row in df.iterrows():
    case_id = row["case_id"]
    for embs, mapping, X in [
        (expr_embs, expr_mapping, expr_X),
        (hist_embs, hist_mapping, hist_X),
        (text_embs, text_mapping, text_X),
    ]:
        fnames = mapping.loc[case_id]
        if isinstance(fnames, str):
            emb = embs[fnames] # fname is guaranteed to be in embeddings
        else:
            # at least one fname will be in embs
            emb = np.mean([embs[f] for f in fnames if f in embs], axis=0)
        X.append(emb)
expr_X = np.asarray(expr_X)
hist_X = np.asarray(hist_X)
text_X = np.asarray(text_X)
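
When a case maps to several files (for example, multiple slides), the loop above mean-pools whichever of those files have embeddings. A toy version of that rule, with hypothetical two-dimensional embeddings and file names, might look like this:

```python
import numpy as np

# Hypothetical embeddings keyed by file name.
toy_embs = {
    "slide_a": np.array([1.0, 3.0]),
    "slide_b": np.array([3.0, 5.0]),
}


def case_embedding(fnames, embs):
    # A single file name maps straight to its embedding.
    if isinstance(fnames, str):
        return embs[fnames]
    # Several file names: average the ones that actually have an embedding.
    present = [embs[f] for f in fnames if f in embs]
    return np.mean(present, axis=0)


single = case_embedding("slide_a", toy_embs)
pooled = case_embedding(["slide_a", "slide_b", "missing"], toy_embs)
```

Mean pooling keeps the embedding dimension fixed regardless of how many files a case has, which is what lets every case share one feature matrix.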

#### Prepare Experiments

In [12]:
demo_X = OneHotEncoder(drop="if_binary", sparse_output=False, dtype=np.float32).fit_transform(df[["sex", "age_binned", "race", "ethnicity"]])
# canc_X = OneHotEncoder(drop="if_binary", sparse_output=False, dtype=np.float32).fit_transform(df[["project_id"]]) # cancer type doesn't change for single project demo

In [13]:
y = np.array(
    list(zip(df["vital_status"] == "Dead", df["days_to_last_fu"])),
    dtype=[("Status", "?"), ("Survival_in_days", "<f8")],
)
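
The survival target is a NumPy structured array with a boolean event field and a float time field, in that order. A small sketch with hypothetical values shows how the fields are addressed by name:

```python
import numpy as np

# Hypothetical outcomes: True means death was observed, False means censored.
status = [True, False, True]
days = [120.0, 800.0, 365.0]

toy_y = np.array(
    list(zip(status, days)),
    dtype=[("Status", "?"), ("Survival_in_days", "<f8")],
)

n_events = int(toy_y["Status"].sum())
longest = float(toy_y["Survival_in_days"].max())
```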

In [14]:
splitter = ( # stratify by observed mortality and cancer type
    df["vital_status"]
    + df["project_id"]
)

( # split all data modalities into train/test
    demo_X_train,
    demo_X_test,
    # canc_X_train,
    # canc_X_test,
    expr_X_train,
    expr_X_test,
    hist_X_train,
    hist_X_test,
    text_X_train,
    text_X_test,
    y_train,
    y_test,
) = train_test_split(
    demo_X,
    # canc_X,
    expr_X,
    hist_X,
    text_X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=splitter,
)

In [15]:
# this one helper will be used to run both unimodal and multimodal experiments
def train_eval_model(
    *,  # enforce kwargs
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    if verbose:
        print(f"Running {name}")

    # z-score input features
    if standardize:
        if verbose:
            print("--standardized")
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
    else:
        X_train_scaled = X_train
        X_test_scaled = X_test

    # dimensionality reduction
    if pca_components is not None:
        if verbose:
            print("--reduced")
        pca = PCA(n_components=pca_components, random_state=42)
        X_train_red = pca.fit_transform(X_train_scaled)
        X_test_red = pca.transform(X_test_scaled)
    else:
        X_train_red = X_train_scaled
        X_test_red = X_test_scaled

    # fit survival model
    cox = CoxPHSurvivalAnalysis(alpha=0.1).fit(X_train_red, y_train)
    if verbose:
        print("--trained")

    # generate predictions
    y_train_pred = cox.predict(X_train_red)
    y_test_pred = cox.predict(X_test_red)

    # evaluate predictions
    c_index = concordance_index_censored(
        event_indicator=y_test["Status"],
        event_time=y_test["Survival_in_days"],
        estimate=y_test_pred,
    )[0]

    return {
        "c_index": c_index,
        "y_test_pred": y_test_pred,
        "y_train_pred": y_train_pred,
        "coxph": cox,
    }
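
The helper scores each model with the concordance index: among comparable patient pairs, the fraction in which the patient with the shorter observed survival received the higher predicted risk. The pure-Python sketch below illustrates the idea on toy data. It is a simplified version that skips pairs with tied times, so it is not a drop-in replacement for `concordance_index_censored`.

```python
def harrell_c_index(event, time, risk):
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a censored patient cannot be the earlier member of a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # earlier event, higher risk: concordant
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk counts as half
    return concordant / comparable


# Hypothetical toy cohort: the third patient is censored.
toy_event = [True, True, False, True]
toy_time = [100, 200, 300, 400]

perfect = harrell_c_index(toy_event, toy_time, [4.0, 3.0, 2.0, 1.0])
inverted = harrell_c_index(toy_event, toy_time, [1.0, 2.0, 3.0, 4.0])
mixed = harrell_c_index(toy_event, toy_time, [4.0, 1.0, 2.0, 3.0])
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, which is the scale on which the results at the end of the notebook are reported.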

#### Train Unimodal Models

In [16]:
demo_results = train_eval_model(X_train=demo_X_train, y_train=y_train, X_test=demo_X_test, y_test=y_test, pca_components=None, standardize=False, name="demo", verbose=True)
# canc_results = train_eval_model(X_train=canc_X_train, y_train=y_train, X_test=canc_X_test, y_test=y_test, pca_components=None, standardize=False, name="canc", verbose=True)
expr_results = train_eval_model(X_train=expr_X_train, y_train=y_train, X_test=expr_X_test, y_test=y_test, pca_components=256, standardize=True, name="expr", verbose=True)
hist_results = train_eval_model(X_train=hist_X_train, y_train=y_train, X_test=hist_X_test, y_test=y_test, pca_components=256, standardize=True, name="hist", verbose=True)
text_results = train_eval_model(X_train=text_X_train, y_train=y_train, X_test=text_X_test, y_test=y_test, pca_components=256, standardize=True, name="text", verbose=True)

Running demo
--trained
Running expr
--standardized
--reduced
--trained
Running hist
--standardized
--reduced
--trained
Running text
--standardized
--reduced
--trained


#### Train Multimodal Model
The unimodal models' predicted risk scores are used as input to the multimodal fusion model.

In [17]:
fuse_X_train = np.asarray([
    demo_results["y_train_pred"],
    # canc_results["y_train_pred"],
    expr_results["y_train_pred"],
    hist_results["y_train_pred"],
    text_results["y_train_pred"],
]).T

fuse_X_test = np.asarray([
    demo_results["y_test_pred"],
    # canc_results["y_test_pred"],
    expr_results["y_test_pred"],
    hist_results["y_test_pred"],
    text_results["y_test_pred"],
]).T

fuse_results = train_eval_model(X_train=fuse_X_train, y_train=y_train, X_test=fuse_X_test, y_test=y_test, pca_components=None, standardize=True, name="fuse", verbose=True)

Running fuse
--standardized
--trained
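
The fusion input is therefore a small matrix with one row per patient and one column per modality, z-scored with statistics from the training split only. A toy sketch of that layout follows; the scores are hypothetical and only two modalities are shown.

```python
import numpy as np

# Hypothetical per-modality risk scores: 4 training and 2 test patients.
train_scores = {"expr": [0.2, -0.1, 0.5, 0.0], "hist": [1.0, 0.0, 2.0, 1.0]}
test_scores = {"expr": [0.3, 0.1], "hist": [0.5, 1.5]}
modalities = ("expr", "hist")

# Stack to shape (n_patients, n_modalities).
toy_fuse_train = np.asarray([train_scores[m] for m in modalities]).T
toy_fuse_test = np.asarray([test_scores[m] for m in modalities]).T

# Standardize using training statistics only, then apply them to the test rows.
mean = toy_fuse_train.mean(axis=0)
std = toy_fuse_train.std(axis=0)
toy_fuse_train_z = (toy_fuse_train - mean) / std
toy_fuse_test_z = (toy_fuse_test - mean) / std
```

Reusing the training mean and standard deviation for the test rows mirrors what `StandardScaler.fit_transform` followed by `transform` does in the helper.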


#### Evaluate Models
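Our multimodal fusion approach beats all unimodal results on this split: the fused model reaches a test C-index of 0.737, compared with 0.686 for the best unimodal model (text)!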

In [18]:
results = pd.Series({
    "demo": demo_results["c_index"],
    # "canc": canc_results["c_index"],
    "expr": expr_results["c_index"],
    "hist": hist_results["c_index"],
    "text": text_results["c_index"],
    "fuse": fuse_results["c_index"],
}, name="C-index")

print("Unimodal Results")
print("------------------")
display(results.loc[["demo", "expr", "hist", "text"]])

print("Multimodal Results")
print("------------------")
display(results.loc[["fuse"]])

Unimodal Results
------------------


demo    0.578692
expr    0.619855
hist    0.654560
text    0.686037
Name: C-index, dtype: float64

Multimodal Results
------------------


fuse    0.736885
Name: C-index, dtype: float64
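
To put a number on the gain from fusion, the printed C-index values can be compared directly. This is a small arithmetic sketch over the values shown above, not an additional experiment:

```python
# C-index values copied from the output above.
c_index = {
    "demo": 0.578692,
    "expr": 0.619855,
    "hist": 0.654560,
    "text": 0.686037,
    "fuse": 0.736885,
}

unimodal = {name: value for name, value in c_index.items() if name != "fuse"}
best_name = max(unimodal, key=unimodal.get)
gain = c_index["fuse"] - unimodal[best_name]

print(f"best unimodal: {best_name}, absolute gain from fusion: {gain:.4f}")
# prints "best unimodal: text, absolute gain from fusion: 0.0508"
```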