# TCGA Cancer Embedding Visualization Notebook

## Pull Embeddings

In [None]:
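# TODO: replace this entire cell with fetching the embeddings from the API.
# NOTE: the commented-out downloads below write to data/, whereas this cell reads
# from with-all-data/data/ -- adjust one or the other before running them.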
# !wget "https://uchicago.box.com/shared/static/k8z0kip2pej2v62pwgymdm45gt7dq0yw.h5" -O "data/hist.h5"
# !wget "https://uchicago.box.com/shared/static/hr82b5c9g3h4y8c7avrnbgvhgdcoetld.h5" -O "data/expr.h5"
# !wget "https://uchicago.box.com/shared/static/liwt3vlvdpmbfsa21wqboshh9nv6enm2.h5" -O "data/summ.h5"

import h5py
from tqdm import tqdm

# Per-modality HDF5 embedding files (the producing model is noted on each line)
expr_file = "with-all-data/data/expr.h5" # BulkRNABert
hist_file = "with-all-data/data/hist.h5" # UNI2
text_file = "with-all-data/data/summ.h5" # BioMistral - Summarized

# One flat lookup table per modality: embedding key -> array read from the file
expr_embs, hist_embs, text_embs = {}, {}, {}

# Each entry is (file to read, dict to fill, whether to upper-case the keys)
for fpath, embs, do_upper in tqdm([
    (expr_file, expr_embs, False),
    (hist_file, hist_embs, False),
    (text_file, text_embs, True),
]):
    with h5py.File(fpath, "r") as h5:
        # The file is two levels deep: top-level groups hold the datasets.
        # Only the dataset name is used as the key, so the grouping is flattened.
        for case_group in h5.values():
            for sample_fname, dataset in case_group.items():
                emb_key = sample_fname.upper() if do_upper else sample_fname
                # [:] reads the whole dataset into memory before the file closes
                embs[emb_key] = dataset[:]

## Pull and Join Metadata

In [None]:
import json
import requests
from io import StringIO

import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from jsonrpcclient import request, parse_json, Error

# helper to parse the SSE-formatted response from the MCP server
def parse_sse(response):
    """Extract the JSON-RPC result from an SSE-formatted MCP response.

    The body is scanned for the ``data:`` field(s) of the first event instead
    of assuming the payload always sits on the second line, so additional
    leading fields or a missing event line no longer break parsing.

    Args:
        response: object exposing the response body as ``.text`` (here a
            ``requests`` response).

    Returns:
        The ``result`` attribute of the parsed JSON-RPC response.

    Raises:
        ValueError: if the body contains no ``data:`` field, or if the parsed
            JSON-RPC response is an ``Error`` (its message is used).
    """
    data_lines = []
    for line in response.text.splitlines():
        if line.startswith("data:"):
            data_lines.append(line[len("data:"):].strip())
        elif data_lines and not line.strip():
            # a blank line terminates the event; ignore anything after it
            break
    if not data_lines:
        raise ValueError("no 'data:' field found in the SSE response body")

    # consecutive data fields of one event are joined with newlines
    x = parse_json("\n".join(data_lines))
    if isinstance(x, Error):
        raise ValueError(x.message)
    return x.result

# helper to use the MCP tool to construct a GDC cohort filter
def get_cohort_filter(query: str) -> str:
    """Ask the GDC Cohort Copilot MCP server to turn a query into a filter.

    Calls the server's ``generate_filter`` tool through a JSON-RPC
    ``tools/call`` request and returns the text of the first content item
    in the result.

    Args:
        query: free-text description of the desired cohort.

    Returns:
        The ``text`` field of the first entry in the result's ``content``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        ValueError: propagated from ``parse_sse`` when the response cannot
            be interpreted or carries a JSON-RPC error.
    """
    response = requests.post(
        "https://apps.m3aicommons.org/gdc-cohort-copilot/gradio_api/mcp/",
        json=request(
            method="tools/call",
            params={"name": "generate_filter", "arguments": {"query": query}},
        ),
        headers={"Accept": "application/json, text/event-stream"},
    )
    # surface HTTP failures directly instead of as a confusing parse error
    response.raise_for_status()
    return parse_sse(response)["content"][0]["text"]

# helper to convert GDC request to file mapping
def get_file_mapping(cohort_filter):
    """Query the GDC files endpoint and map case IDs to stripped file names.

    Args:
        cohort_filter: value sent as the ``filters`` query parameter (here the
            string produced by ``get_cohort_filter``).

    Returns:
        A ``pandas.Series`` named ``file_name`` indexed by ``case_id`` (taken
        from the ``cases.0.submitter_id`` column of the TSV response). The
        index may contain duplicates when a case has several files. The
        literal substrings ``.svs``, ``.rna_seq.augmented_star_gene_counts.tsv``
        and ``.PDF`` are removed from each file name.

    Raises:
        requests.HTTPError: if the GDC API answers with a 4xx/5xx status.
    """
    response = requests.get(
        "https://api.gdc.cancer.gov/files",
        params={
            "filters": cohort_filter, "format": "TSV", "size": str(100_000),
            "fields": ",".join(["file_name", "cases.submitter_id"]),
        },
    )
    # fail on the HTTP status rather than trying to parse an error page as TSV
    response.raise_for_status()

    df = pd.read_csv(StringIO(response.text), sep="\t")
    df = df[["cases.0.submitter_id", "file_name"]]
    df = df.rename(columns={"cases.0.submitter_id": "case_id"})

    file_names = df["file_name"]
    for suffix in (".svs", ".rna_seq.augmented_star_gene_counts.tsv", ".PDF"):
        # regex=False: the default differs between pandas versions, and under
        # regex mode the leading "." would match any character
        file_names = file_names.str.replace(suffix, "", regex=False)
    df["file_name"] = file_names
    return df.set_index("case_id")["file_name"]

# Build one cohort filter per modality (print them to see what they look like!)
hist_filter, expr_filter, text_filter = (
    get_cohort_filter(query)
    for query in (
        "TCGA samples with diagnostic slide svs files",
        "TCGA samples with RNA-sequences gene expression quantification data as tsv files",
        "TCGA samples with pathology report pdf files",
    )
)

# Resolve each filter through the API into a mapping of
# case ID --> filename (from the embedding service)
hist_mapping, expr_mapping, text_mapping = map(
    get_file_mapping, (hist_filter, expr_filter, text_filter)
)

# Get clinical metadata, this is standard code to pull data from the GDC
# (though we use an extra call to the GDC Cohort Copilot for fun)
response = requests.get(
    "https://api.gdc.cancer.gov/cases",
    params={
        "filters": get_cohort_filter("samples from TCGA"), "format": "JSON", "size": str(100_000),
        "fields": ",".join(
            [
                "project.project_id",
                "submitter_id",
                "diagnoses.age_at_diagnosis",
                "diagnoses.diagnosis_is_primary_disease",
                "diagnoses.days_to_last_follow_up",
                "demographic.days_to_death",
                "demographic.vital_status",
                "demographic.ethnicity",
                "demographic.gender",
                "demographic.race",
            ]
        ),
    },
)
# Stop here on a 4xx/5xx answer instead of failing later with a KeyError
# while walking an error payload
response.raise_for_status()
temp = response.json()

# Flatten each case hit into one row built from its primary diagnosis.
data = []
for sample in temp["data"]["hits"]:
    # first diagnosis flagged as the primary disease, or None when the hit
    # has no diagnoses / none of them is primary
    diagnosis = next(
        (
            d for d in sample.get("diagnoses", [])
            if d.get("diagnosis_is_primary_disease")
        ),
        None,
    )
    if diagnosis is None:
        continue
    if "days_to_last_follow_up" not in diagnosis:
        # primary diagnosis does not have f/u
        continue

    # Fields can be absent from a hit (the checks above already rely on
    # that), so read the remaining optional ones with .get(): a missing value
    # becomes None and the row is removed by the later notna()/isin() filters
    # instead of aborting the whole loop with a KeyError.
    demographic = sample.get("demographic", {})
    data.append({
        "project_id": sample["project"]["project_id"],
        "case_id": sample["submitter_id"],
        "age_at_diagnosis": diagnosis.get("age_at_diagnosis"),
        "days_to_last_fu": diagnosis["days_to_last_follow_up"],
        "sex": demographic.get("gender"),
        "race": demographic.get("race"),
        "ethnicity": demographic.get("ethnicity"),
        "vital_status": demographic.get("vital_status"),
    })

case_ids = ( # case IDs that have all 3 embedding modalities
    set(expr_mapping[expr_mapping.isin(expr_embs)].index) &
    set(hist_mapping[hist_mapping.isin(hist_embs)].index) &
    set(text_mapping[text_mapping.isin(text_embs)].index)
)

df = pd.DataFrame(data)
# Keep cases with all modalities, a known vital status, and usable
# age / follow-up values. The explicit .copy() makes the result an
# independent frame, so the column assignments below are not chained
# assignments on a slice (which triggers SettingWithCopyWarning).
df = df[
    df["case_id"].isin(case_ids)
    & df["vital_status"].isin({"Alive", "Dead"})
    & df["age_at_diagnosis"].notna()
    & df["days_to_last_fu"].notna()
    & (df["days_to_last_fu"] >= 0)
].copy()
df["age_at_diagnosis"] /= 365.24 # convert age to years
# 20-year wide bins covering 0-100; ages outside that range become NaN
df["age_binned"] = pd.cut(df["age_at_diagnosis"], bins=np.arange(0, 101, 20))

# positional index 0..n-1 so that rows line up with the embedding matrices
df = df.reset_index(drop=True)

# Per-modality embeddings, appended in the same order as the rows of df
expr_X, hist_X, text_X = [], [], []
modality_sources = [
    (expr_embs, expr_mapping, expr_X),
    (hist_embs, hist_mapping, hist_X),
    (text_embs, text_mapping, text_X),
]
for case_id in df["case_id"]:
    for embs, mapping, X in modality_sources:
        # .loc yields a plain string for a single file and a Series when the
        # case ID occurs several times in the mapping
        fnames = mapping.loc[case_id]
        if isinstance(fnames, str):
            # fname is guaranteed to be in embeddings
            emb = embs[fnames]
        else:
            # at least one fname will be in embs; average those that are
            available = [embs[f] for f in fnames if f in embs]
            emb = np.mean(available, axis=0)
        X.append(emb)

# Stack the collected embeddings into arrays
expr_X, hist_X, text_X = (
    np.asarray(rows) for rows in (expr_X, hist_X, text_X)
)

## Prepare Data for Visualization

In [None]:
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Preprocessing template; each modality is transformed by its own clone so
# the fitted scalers / PCA are never shared between modalities
pipe = Pipeline(
    steps=[
        ("scale_raw", StandardScaler()),
        # reduce dimensionality so UMAP runs faster
        ("pca_reduce", PCA(n_components=32, random_state=42)),
        # normalize PCA output
        ("scale_pca", StandardScaler()),
    ],
)

# modality name -> preprocessed feature matrix
modality_map = {
    name: clone(pipe).fit_transform(raw_X)
    for name, raw_X in (("expr", expr_X), ("hist", hist_X), ("text", text_X))
}

# Columns of df that can be used for coloring, split by kind
cat_outputs = ["project_id", "age_binned", "sex", "race", "ethnicity", "vital_status"]
cont_outputs = ["age_at_diagnosis", "days_to_last_fu"]

# column name -> "cat" (categorical) or "cont" (continuous)
output_types = {
    **dict.fromkeys(cat_outputs, "cat"),
    **dict.fromkeys(cont_outputs, "cont"),
}

inputs = list(modality_map)

## Visualization Widget

In [None]:
from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import Checkbox, HBox, VBox, RadioButtons, Output, Label, Layout, HTML, interactive_output

# One checkbox per embedding modality; only the first one starts checked
controls = dict(
    (modality, Checkbox(value=(position == 0), description=modality, indent=False))
    for position, modality in enumerate(inputs)
)

# Radio buttons choosing the column used to color the points
# (categorical columns listed first, then continuous ones)
controls["color_selector"] = RadioButtons(
    options=[*cat_outputs, *cont_outputs],
    value="project_id",
    description="Color by:",
)

# Output area that the plot (or the spinner) is rendered into
output_plot = Output()

# HTML/CSS snippet for a centered, rotating loader; update() displays it in
# the output area before it starts computing the plot.
spinner_html = """
<div style="display: flex; justify-content: center; align-items: center; width: 100%; height: 500px;">
    <div class="loader"></div>
</div>

<style>
.loader {
  border: 6px solid #f3f3f3;
  border-top: 6px solid #3498db;
  border-radius: 50%;
  width: 30px;
  height: 30px;
  animation: spin 1s linear infinite;
}

@keyframes spin {
  0% { transform: rotate(0deg); }
  100% { transform: rotate(360deg); }
}
</style>
"""

# Cache filled by update(): comma-joined, sorted modality names -> UMAP output
# for that combination, so each combination is only projected once.
cached_modalities = dict()

def update(color_selector, **modality_selections):
    """Redraw the UMAP scatter plot for the current widget selection.

    The preprocessed matrices of all checked modalities are concatenated
    column-wise and projected with UMAP; the projection of every modality
    combination is stored in ``cached_modalities`` and reused afterwards.
    The plot is drawn inside ``output_plot``.

    The plotting columns are added to a temporary copy of ``df`` rather than
    to the shared ``df`` itself, so repeated calls no longer accumulate
    ``UMAP-*`` and display-label columns on the global frame, and the figure
    is closed once shown so figures do not pile up across updates.

    Args:
        color_selector: name of the ``df`` column used to color the points;
            must be a key of ``output_types``.
        **modality_selections: modality name -> bool, whether that modality
            is included in the projection.

    Returns:
        None. Prints a hint and returns early when no modality is selected.
    """
    with output_plot:
        # Show the spinner right away; clear_output(wait=True) defers clearing
        # it until the next output (the message or the figure) arrives.
        output_plot.clear_output(wait=False)
        display(HTML(spinner_html))
        output_plot.clear_output(wait=True)

        # sorted so that the cache key does not depend on argument order
        modalities = sorted(
            name for name, checked in modality_selections.items() if checked
        )
        if not modalities:
            print("Please select at least 1 modality")
            return
        modalities_str = ", ".join(modalities)

        if modalities_str not in cached_modalities:
            X = np.concatenate([modality_map[modality] for modality in modalities], axis=1)
            cached_modalities[modalities_str] = UMAP(n_neighbors=30, n_jobs=4).fit_transform(X)
        reduced = cached_modalities[modalities_str]

        # Human-readable legend title, e.g. "days_to_last_fu" -> "Days To Last FU"
        hue_col = color_selector.replace("_", " ").title().replace("Fu", "FU").replace("Id", "ID").replace("At", "at")
        output_type = output_types[color_selector]
        if output_type == "cat":
            hue_values = pd.Categorical(df[color_selector])
            kwargs = {}
        else:
            hue_values = df[color_selector]
            kwargs = {"palette": "viridis"}

        # local frame: leaves the shared df untouched
        plot_df = df.assign(**{
            "UMAP-1": reduced[:, 0],
            "UMAP-2": reduced[:, 1],
            hue_col: hue_values,
        })

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))
        sns.scatterplot(data=plot_df, x="UMAP-1", y="UMAP-2", hue=hue_col, s=10, ax=ax, **kwargs)
        # place the legend outside the axes, to the right of the plot
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), markerscale=2)
        plt.show()
        # release the figure so repeated updates do not keep old figures alive
        plt.close(fig)

# Connect the control widgets to update()
interactive = interactive_output(update, controls)

# Left column: a label followed by every control widget; right column: the plot area
sidebar_children = [Label("Embedding Modality:"), *controls.values()]
checkbox_column = VBox(sidebar_children, layout=Layout(width="150px"))
plot_column = VBox([output_plot], layout=Layout(width="1000px"))
ui = HBox([checkbox_column, plot_column])
display(ui, interactive)