# Multimodal Cancer Modeling in the Age of Foundation Model Embeddings
### Steven Song\*, Morgan Borjigin-Wang\*, Irene R. Madejski, Robert L. Grossman

\* Equal contribution

Read our paper here: https://arxiv.org/abs/2505.07683

***

The Cancer Genome Atlas (TCGA) has enabled novel discoveries and served as a large-scale reference dataset in cancer through its harmonized genomics, clinical, and imaging data. Numerous prior studies have developed bespoke deep learning models over TCGA for tasks such as cancer survival prediction. A modern paradigm in biomedical deep learning is the development of foundation models (FMs) to derive feature embeddings agnostic to a specific modeling task. Biomedical text especially has seen growing development of FMs. While TCGA contains free-text data as pathology reports, these have been historically underutilized.

* **We investigate the ability to train classical machine learning models over multimodal, zero-shot FM embeddings of cancer data.**
* We demonstrate the ease and additive effect of multimodal fusion, outperforming unimodal models.
* Overall, we propose a simple, modernized approach to multimodal cancer modeling using FM embeddings.

### Overview

<img src="https://raw.githubusercontent.com/StevenSong/multimodal-cancer-modeling/refs/heads/main/overview.png" alt="conceptual overview figure" width="50%"/>

Conceptually, the proposed framework performs late fusion of unimodal models, each trained over the embeddings of its own modality. Specifically, we use:
* BulkRNABert ([Gélard et al. 2025](https://proceedings.mlr.press/v259/gelard25a.html)) for RNA-seq data
* UNI2-h ([Chen et al. 2024](https://www.nature.com/articles/s41591-024-02857-3)) for histology data
* BioMistral ([Labrak et al. 2024](https://aclanthology.org/2024.findings-acl.348/)) for pathology report data (summarized by Llama-3.1-8B-Instruct ([Grattafiori et al. 2024](https://arxiv.org/abs/2407.21783)))

Our approach breaks down into the following steps:
1. Get TCGA embeddings
1. Get TCGA metadata
1. Merge embeddings and metadata
1. Prepare experiments
1. Train unimodal models
1. Train multimodal model
1. Evaluate models

# Survival Experiments

This notebook contains all our code for survival modeling.

The experiments test multimodal fusion of survival models and varying dimensionality reductions for high-dimensional embeddings.

Specifically, we experiment with 5 modalities:
* Patient demographics (sex, binned age, race, ethnicity)
* Cancer type (we use the TCGA project ID as a proxy for cancer type)
* RNA-seq gene expression (`BulkRNABert` embeddings)
* Whole slide histology images (`UNI2` embeddings)
* Pathology reports (`BioMistral` embeddings)

We additionally experiment with various alternate embeddings, including:
* `BioMistral` embeddings of pathology report summaries generated by `Llama-3.1-8B-Instruct`
* `Mistral-7B-Instruct-v0.1` embeddings of pathology reports
* `Mistral-7B-Instruct-v0.1` embeddings of pathology report summaries generated by `Llama-3.1-8B-Instruct`
* `UCE` embeddings of RNA-seq gene expression

To use these alternate embeddings, modify the variables for input/output files in the first code cell of this notebook.

Run experiments by executing all cells of this notebook. Results are saved in the `results` subdirectory at the root of the repo. Analysis and visualization are done using tools that are also in the `results` folder.

In [1]:
expr_file = "expr.h5" # BulkRNABert
hist_file = "hist.h5" # UNI2
text_file = "summ.h5" # BioMistral - Summarized
output_results = "results/results_summarized.csv"
output_predictions = "results/predictions_summarized.npy"

In [2]:
from itertools import chain, combinations
from collections import defaultdict
import h5py
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

In [3]:
df = pd.read_csv("clinical.csv")
clin_case_ids = set(df["case_id"])

with h5py.File(expr_file, "r") as expr_h5:
    expr_case_ids = set(expr_h5.keys())

with h5py.File(hist_file, "r") as hist_h5:
    hist_case_ids = set(hist_h5.keys())

with h5py.File(text_file, "r") as text_h5:
    text_case_ids = set(text_h5.keys())

In [4]:
case_ids = sorted(clin_case_ids & expr_case_ids & hist_case_ids & text_case_ids)

df = df[df["case_id"].isin(case_ids)]
df = df.sort_values("case_id").reset_index(drop=True)
assert df["case_id"].is_unique

In [5]:
df.shape

(7982, 9)

In [6]:
df["age_binned"] = pd.cut(
    df["age"],
    bins=[0, 20, 40, 60, 80, 100],
    labels=["(0, 20]", "(20, 40]", "(40, 60]", "(60, 80]", "(80, 100]"],
)

In [7]:
dead = df["vital_status"] == "Dead"
days_to_event = np.where(dead, df["days_to_death"], df["days_to_last_follow_up"])
assert not np.isnan(days_to_event).any()

In [8]:
y = np.array(list(zip(dead, days_to_event)), dtype=[('Status', '?'), ('Survival_in_days', '<f8')])
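The structured array above packs the event indicator and the follow-up time into one record per patient, which is the target format that scikit-survival estimators expect. A minimal sketch with made-up values (the three patients below are hypothetical, not TCGA data):

```python
import numpy as np

# Hypothetical toy values, not TCGA data.
dead = np.array([True, False, True])
days_to_event = np.array([120.0, 900.0, 45.0])

# Same dtype as in the cell above: a boolean event flag and a float64 time.
y_toy = np.array(
    list(zip(dead, days_to_event)),
    dtype=[("Status", "?"), ("Survival_in_days", "<f8")],
)

print(y_toy["Status"])            # event indicator per patient
print(y_toy["Survival_in_days"])  # time to death or censoring
print(y_toy[[0, 2]])              # row indexing keeps both fields together
```

Because both fields live in one array, indexing with the train or test indices of a split selects the status and the time consistently.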

In [9]:
demo_ohe = OneHotEncoder(drop="if_binary", sparse_output=False, dtype=np.float32)
canc_ohe = OneHotEncoder(drop="if_binary", sparse_output=False, dtype=np.float32)

In [10]:
demo_X = demo_ohe.fit_transform(df[["sex", "age_binned", "race", "ethnicity"]])
canc_X = canc_ohe.fit_transform(df[["project"]])

In [11]:
demo_X.shape

(7982, 17)

In [12]:
canc_X.shape

(7982, 32)

In [13]:
demo_ohe.categories_

[array(['female', 'male'], dtype=object),
 array(['(0, 20]', '(20, 40]', '(40, 60]', '(60, 80]', '(80, 100]'],
       dtype=object),
 array(['Unknown', 'american indian or alaska native', 'asian',
        'black or african american',
        'native hawaiian or other pacific islander', 'not reported',
        'white'], dtype=object),
 array(['Unknown', 'hispanic or latino', 'not hispanic or latino',
        'not reported'], dtype=object)]

In [14]:
canc_ohe.categories_

[array(['TCGA-ACC', 'TCGA-BLCA', 'TCGA-BRCA', 'TCGA-CESC', 'TCGA-CHOL',
        'TCGA-COAD', 'TCGA-DLBC', 'TCGA-ESCA', 'TCGA-GBM', 'TCGA-HNSC',
        'TCGA-KICH', 'TCGA-KIRC', 'TCGA-KIRP', 'TCGA-LGG', 'TCGA-LIHC',
        'TCGA-LUAD', 'TCGA-LUSC', 'TCGA-MESO', 'TCGA-OV', 'TCGA-PAAD',
        'TCGA-PCPG', 'TCGA-PRAD', 'TCGA-READ', 'TCGA-SARC', 'TCGA-SKCM',
        'TCGA-STAD', 'TCGA-TGCT', 'TCGA-THCA', 'TCGA-THYM', 'TCGA-UCEC',
        'TCGA-UCS', 'TCGA-UVM'], dtype=object)]

In [15]:
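def extract_case_emb_from_h5(case_ids: list[str], h5: h5py.File) -> np.ndarray:
    """Return one embedding per case, averaging all embeddings stored for that case."""
    X = []
    for case_id in tqdm(case_ids):
        case_group = h5[case_id]
        # a case may hold several embeddings; stack them and mean-pool
        embs = np.stack([v[:] for v in case_group.values()], axis=0)
        emb = np.mean(embs, axis=0)
        X.append(emb)
    return np.stack(X, axis=0)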
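`extract_case_emb_from_h5` reads each case as an HDF5 group holding one or more equally sized embedding vectors and averages them into a single case-level vector. The sketch below mimics that pooling with a plain dictionary standing in for the HDF5 file; the case IDs, dataset names, and values are hypothetical.

```python
import numpy as np

# Hypothetical stand-in for an HDF5 file: case_id -> {dataset_name: embedding}.
fake_h5 = {
    "case-A": {"emb-1": np.array([1.0, 2.0]), "emb-2": np.array([3.0, 4.0])},
    "case-B": {"emb-1": np.array([5.0, 6.0])},
}

def mean_pool_cases(case_ids, store):
    """Average all embeddings of a case into one vector, one row per case."""
    rows = []
    for case_id in case_ids:
        embs = np.stack(list(store[case_id].values()), axis=0)
        rows.append(embs.mean(axis=0))
    return np.stack(rows, axis=0)

pooled = mean_pool_cases(["case-A", "case-B"], fake_h5)
print(pooled.shape)
print(pooled)
```

Row order follows the order of `case_ids`, which is why the notebook sorts the case IDs once and reuses the same list for every modality.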

In [16]:
with h5py.File(expr_file, "r") as expr_h5:
    expr_X = extract_case_emb_from_h5(case_ids, expr_h5)

with h5py.File(hist_file, "r") as hist_h5:
    hist_X = extract_case_emb_from_h5(case_ids, hist_h5)

with h5py.File(text_file, "r") as text_h5:
    text_X = extract_case_emb_from_h5(case_ids, text_h5)

100%|██████████| 7982/7982 [00:02<00:00, 2715.79it/s]
100%|██████████| 7982/7982 [00:05<00:00, 1460.10it/s]
100%|██████████| 7982/7982 [00:07<00:00, 1033.27it/s]


In [17]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
splitter = (
    df["vital_status"]
    + "_"
    + df["project"]
    + "_"
    + df["sex"]
    + "_"
    + df["age_binned"].astype(str)
    + "_"
    + df["race"]
    + "_"
    + df["ethnicity"]
)

n = len(df)
test_splits = [split_idxs for _, split_idxs in skf.split(X=np.zeros(n), y=splitter)]
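The composite label stratifies folds on outcome, cancer type, and demographics at once, so some strata can be very small. scikit-learn's `StratifiedKFold` emits a warning when the least populated class has fewer members than `n_splits`. One way to see which strata are affected, sketched here with hypothetical labels rather than the real `splitter` series:

```python
from collections import Counter

n_splits = 5

# Hypothetical composite labels; in the notebook these come from `splitter`.
labels = (
    ["Alive_TCGA-LUAD_male"] * 7
    + ["Dead_TCGA-LUAD_female"] * 5
    + ["Dead_TCGA-CHOL_male"] * 2
)

counts = Counter(labels)
# Strata with fewer members than folds cannot appear in every test fold.
rare = {label: k for label, k in counts.items() if k < n_splits}
print(rare)  # {'Dead_TCGA-CHOL_male': 2}
```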



In [18]:
splitter

0       Alive_TCGA-LUAD_male_(60, 80]_not reported_not...
1       Alive_TCGA-LUAD_male_(60, 80]_not reported_not...
2       Dead_TCGA-LUAD_female_(60, 80]_not reported_no...
3       Alive_TCGA-LUAD_male_(60, 80]_not reported_not...
4       Dead_TCGA-LUAD_male_(60, 80]_not reported_not ...
                              ...                        
7977    Alive_TCGA-LIHC_female_(60, 80]_white_not hisp...
7978    Alive_TCGA-LIHC_male_(40, 60]_white_not hispan...
7979    Alive_TCGA-THYM_female_(60, 80]_white_not hisp...
7980    Dead_TCGA-CHOL_male_(40, 60]_white_not hispani...
7981    Alive_TCGA-CESC_female_(60, 80]_white_not hisp...
Length: 7982, dtype: object

In [19]:
meta_df = df[["case_id"]].copy()
meta_df["split"] = -1
meta_df["split_order"] = -1
for i, test_idxs in enumerate(test_splits):
    meta_df.loc[test_idxs, "split"] = i
    meta_df.loc[test_idxs, "split_order"] = list(range(len(test_idxs)))
meta_df["dead"] = y["Status"]
meta_df["days_to_death_or_censor"] = y["Survival_in_days"]
meta_df.to_csv("results/split_cases.csv", index=False)

In [20]:
meta_df["split"].value_counts()

split
0    1597
1    1597
3    1596
4    1596
2    1596
Name: count, dtype: int64

In [21]:
def run_split(
    *,  # enforce kwargs
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    if verbose:
        print(f"Running {name}")

    # z-score input features
    if standardize:
        if verbose:
            print("--standardized")
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
    else:
        X_train_scaled = X_train
        X_test_scaled = X_test

    # dimensionality reduction
    if pca_components is not None:
        if verbose:
            print("--reduced")
        pca = PCA(n_components=pca_components, random_state=42)
        X_train_red = pca.fit_transform(X_train_scaled)
        X_test_red = pca.transform(X_test_scaled)
    else:
        X_train_red = X_train_scaled
        X_test_red = X_test_scaled

    # fit survival model
    cox = CoxPHSurvivalAnalysis(alpha=0.1).fit(X_train_red, y_train)

    # generate predictions
    y_train_pred = cox.predict(X_train_red)
    y_test_pred = cox.predict(X_test_red)

    # evaluate predictions
    c_index = concordance_index_censored(
        event_indicator=y_test["Status"],
        event_time=y_test["Survival_in_days"],
        estimate=y_test_pred,
    )[0]

    return {
        "c_index": c_index,
        "y_test_pred": y_test_pred,
        "y_train_pred": y_train_pred,
    }

def run_unimodal_split(
    *,  # enforce kwargs
    X: np.ndarray,
    y: np.ndarray,
    test_idxs: np.ndarray,
    train_idxs: np.ndarray,
    pca_components: int | None,
    standardize: bool,
    name: str = "",
    verbose: bool = False,
) -> dict:
    # split matrices
    X_train, X_test = X[train_idxs], X[test_idxs]
    y_train, y_test = y[train_idxs], y[test_idxs]

    return run_split(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        pca_components=pca_components,
        standardize=standardize,
        name=name,
        verbose=verbose,
    )

def powerset(s):
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
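`powerset` yields every subset of the modality list, including the empty set and the singletons; the experiment loop later keeps only subsets with at least two modalities. A short check of how many fusion models that implies for five modalities:

```python
from itertools import chain, combinations

def powerset(s):
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))

modalities = ["demo", "canc", "expr", "hist", "text"]
combos = ["-".join(sorted(c)) for c in powerset(modalities) if len(c) > 1]

print(len(combos))  # 26 = 2**5 - 1 (empty set) - 5 (singletons)
print(combos[0])    # canc-demo
print(combos[-1])   # canc-demo-expr-hist-text
```

Sorting each subset before joining gives every combination a stable key, so the full fusion is always stored as `canc-demo-expr-hist-text` regardless of the order in which modalities are listed.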

In [22]:
def run_experiment(pca_components: int) -> list[dict]:
    results = []
    for test_idxs in tqdm(test_splits, desc="Cross Validation Splits"):
        split_results = dict()

        temp = set(test_idxs)
        train_idxs = [i for i in range(n) if i not in temp]

        split_results["demo"] = run_unimodal_split(X=demo_X, y=y, test_idxs=test_idxs, train_idxs=train_idxs, pca_components=None, standardize=False)
        split_results["canc"] = run_unimodal_split(X=canc_X, y=y, test_idxs=test_idxs, train_idxs=train_idxs, pca_components=None, standardize=False)
        split_results["expr"] = run_unimodal_split(X=expr_X, y=y, test_idxs=test_idxs, train_idxs=train_idxs, pca_components=pca_components, standardize=True)
        split_results["hist"] = run_unimodal_split(X=hist_X, y=y, test_idxs=test_idxs, train_idxs=train_idxs, pca_components=pca_components, standardize=True)
        split_results["text"] = run_unimodal_split(X=text_X, y=y, test_idxs=test_idxs, train_idxs=train_idxs, pca_components=pca_components, standardize=True)

        y_train, y_test = y[train_idxs], y[test_idxs]

        combos = [sorted(x) for x in powerset(["demo", "canc", "expr", "hist", "text"]) if len(x) > 1]
        for combo in combos:
            mult_X_train = []
            mult_X_test = []
            for modality in combo:
                x_train = split_results[modality]["y_train_pred"][:, np.newaxis]
                x_test = split_results[modality]["y_test_pred"][:, np.newaxis]
                # z-score all unimodal risks
                scaler = StandardScaler()
                x_train = scaler.fit_transform(x_train)
                x_test = scaler.transform(x_test)
                mult_X_train.append(x_train)
                mult_X_test.append(x_test)

            mult_X_train = np.concatenate(mult_X_train, axis=1)
            mult_X_test = np.concatenate(mult_X_test, axis=1)

            split_results["-".join(combo)] = run_split(X_train=mult_X_train, y_train=y_train, X_test=mult_X_test, y_test=y_test, pca_components=None, standardize=False)

        results.append(split_results)
    return results
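The fusion step treats each unimodal model's predicted risk as a single feature: risks are z-scored with statistics from the training fold only, then concatenated column-wise into the input for a second Cox model. A NumPy-only sketch of that feature construction, with hypothetical risk values and a hand-written z-score standing in for `StandardScaler`:

```python
import numpy as np

# Hypothetical unimodal risk scores (train, test) for two modalities.
risks = {
    "expr": (np.array([0.0, 1.0, 2.0]), np.array([1.0, 3.0])),
    "hist": (np.array([10.0, 20.0, 30.0]), np.array([20.0, 40.0])),
}

train_cols, test_cols = [], []
for modality in sorted(risks):
    r_train, r_test = risks[modality]
    # Fit mean/std on the training fold only, then apply to both folds.
    mu, sigma = r_train.mean(), r_train.std()
    train_cols.append(((r_train - mu) / sigma)[:, np.newaxis])
    test_cols.append(((r_test - mu) / sigma)[:, np.newaxis])

fused_train = np.concatenate(train_cols, axis=1)  # shape (n_train, n_modalities)
fused_test = np.concatenate(test_cols, axis=1)    # shape (n_test, n_modalities)
print(fused_train.shape, fused_test.shape)
```

In the notebook, the training-fold risks are each unimodal model's predictions on the same fold it was fit on (`y_train_pred`), and the fused matrix is passed to `run_split` without further standardization or PCA.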

In [23]:
# run experiments for 256 components based on paper
results = dict()
pca_components = 256
results[pca_components] = run_experiment(pca_components=pca_components)
np.save(output_predictions, results)

Cross Validation Splits: 100%|██████████| 5/5 [03:38<00:00, 43.78s/it]
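Because `results` is a nested dictionary rather than an array, `np.save` stores it as a pickled object. Reading it back therefore needs `allow_pickle=True`, plus `.item()` to unwrap the dictionary. A minimal sketch using a temporary file and a hypothetical miniature of `results`:

```python
import os
import tempfile
import numpy as np

# Hypothetical miniature of the notebook's `results` structure:
# {pca_components: [per-split dict of {combo: {"c_index": ...}}]}
toy_results = {256: [{"demo": {"c_index": 0.5}}]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "predictions.npy")
    np.save(path, toy_results)  # pickles the dict inside a 0-d object array
    loaded = np.load(path, allow_pickle=True).item()

print(loaded[256][0]["demo"]["c_index"])  # 0.5
```

As with any pickle, only load such files from a source you trust.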


In [24]:
indiv_combos = ["canc", "demo", "expr", "hist", "text"]
fusion = '-'.join(indiv_combos)
print(fusion)

canc-demo-expr-hist-text


In [27]:
print("Unimodal Results")
print("------------------")
for pca_components in tqdm([256]):
    for combo in indiv_combos:
        c_idxs = []
        for i in range(5):
            c_idx = results[pca_components][i][combo]["c_index"]
            c_idxs.append(c_idx)
        c_idx = np.mean(c_idxs)
        print(f'{combo}:{c_idx}')

print("Multimodal Results")
print("------------------")
for pca_components in tqdm([256]):
    for combo in [fusion]:
        c_idxs = []
        for i in range(5):
            c_idx = results[pca_components][i][combo]["c_index"]
            c_idxs.append(c_idx)
        c_idx = np.mean(c_idxs)
        print(f'{combo}:{c_idx}')

Unimodal Results
------------------


100%|██████████| 1/1 [00:00<00:00, 3855.06it/s]


canc:0.7370195470669764
demo:0.6298858244261268
expr:0.7495909587116658
hist:0.7539801485317722
text:0.7519210591393961
Multimodal Results
------------------


100%|██████████| 1/1 [00:00<00:00, 8905.10it/s]

canc-demo-expr-hist-text:0.7932592520350699
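The scores above are concordance indices: the fraction of comparable patient pairs whose predicted risks are ordered consistently with their observed survival times. The pure-Python sketch below illustrates the idea with a simplified Harrell's C. It is not scikit-survival's implementation, and as a simplifying assumption it only treats a pair as comparable when the patient with the strictly shorter time had an observed event (ties in time are skipped). The data are hypothetical.

```python
def simple_c_index(event, time, risk):
    """Simplified Harrell's C: higher risk should mean shorter survival."""
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        for j in range(n):
            # Pair is comparable if i had an observed event before j's time.
            if event[i] and time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risks count as half
    return concordant / comparable

event = [True, True, False, True]
time = [5.0, 10.0, 12.0, 20.0]
risk = [0.9, 0.4, 0.6, 0.1]

print(simple_c_index(event, time, risk))  # 0.8
```

Here five pairs are comparable and four are ordered correctly; the one miss is the patient who died at day 10 but was assigned a lower risk than the patient censored at day 12. A value of 0.5 corresponds to random ordering and 1.0 to perfect ordering.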



