In [None]:
# Finding names offered in the "Color by" selector; each one is passed as the
# `hue` column when plotting `df`, so every entry must be a column of `df`.
labels = [
    "Cardiomegaly",
    "Edema",
    "Consolidation",
    "Atelectasis",
    "Pleural Effusion",
]

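# TODO: replace this entire cell with code that fetches the embeddings and metadata from the API.
# Until then, the two files can be downloaded with the commands below.
# NOTE: these commands save into data/, while the loading code in this cell reads from
# with-all-data/data/ -- adjust one or the other so the paths agree.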
# !wget "https://uchicago.box.com/shared/static/er0sz3yt17ke9zyg97ruhgemntuhx4r1.h5" -O "data/openi.h5"
# !wget "https://uchicago.box.com/shared/static/rasdb3b3xuirx7q4k012vnnx558px3we.csv" -O "data/openi.csv"

import h5py
import pandas as pd
import numpy as np

# Directory that holds both the metadata CSV and the HDF5 embedding store.
DATA_DIR = "with-all-data/data"

# Study-level metadata; the `study_id` column keys into the HDF5 hierarchy.
df = pd.read_csv(f"{DATA_DIR}/openi.csv")

X_emb = []
X_proj = []
with h5py.File(f"{DATA_DIR}/openi.h5", "r") as h5:
    # Only `study_id` is needed, so iterate that column directly instead of
    # materializing a Series per row with `iterrows()`.
    for sid in df["study_id"]:
        # The patient id is the study id with "study" replaced by "patient".
        pid = sid.replace("study", "patient")

        # Resolve the per-study groups once rather than re-walking the path
        # for every access.
        embed_group = h5[f"img_embed/{pid}/{sid}"]
        proj_group = h5[f"img_proj/{pid}/{sid}"]

        # we know there's only 1 image per study in this h5 file, doesn't matter for API
        did = next(iter(embed_group), None)
        if did is None:
            raise KeyError(f"no image embedding stored for study {sid!r}")

        # `[:]` reads the full dataset into memory as a NumPy array.
        X_emb.append(embed_group[did][:])
        X_proj.append(proj_group[did][:])

# One entry per row of `df`, in the same order as `df`.
X_emb = np.asarray(X_emb)
X_proj = np.asarray(X_proj)

In [None]:
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Preprocessing template shared by all embedding types; at the moment it
# consists of a single feature-standardization step.
pipe = Pipeline(steps=[("scale_raw", StandardScaler())])

# Fit an independent copy of the template on each embedding matrix so the
# scalers do not share fitted statistics. Keys are the modality names.
modality_map = {
    name: clone(pipe).fit_transform(raw)
    for name, raw in (("img_embed", X_emb), ("img_proj", X_proj))
}

# Modality names in insertion order; used as the embedding selector options.
inputs = list(modality_map)

In [None]:
from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import HBox, VBox, RadioButtons, Output, Label, Layout, HTML, interactive_output

# Radio-button groups driving the plot. The keys must match the parameter
# names of `update`, since `interactive_output` maps each key to the keyword
# argument of the same name.
controls = dict(
    embedding_selector=RadioButtons(
        options=inputs,
        value="img_embed",
        description="Embedding Type:",
    ),
    color_selector=RadioButtons(
        options=labels,
        value="Atelectasis",
        description="Color by:",
    ),
)

# Output area that `update` draws the spinner and the scatter plot into.
output_plot = Output()

# Markup for a pure-CSS loading spinner, centered in a 500px-high box, that
# `update` shows while the plot is being computed.
spinner_html = """
<div style="display: flex; justify-content: center; align-items: center; width: 100%; height: 500px;">
    <div class="loader"></div>
</div>

<style>
.loader {
  border: 6px solid #f3f3f3;
  border-top: 6px solid #3498db;
  border-radius: 50%;
  width: 30px;
  height: 30px;
  animation: spin 1s linear infinite;
}

@keyframes spin {
  0% { transform: rotate(0deg); }
  100% { transform: rotate(360deg); }
}
</style>
"""

# Memo of UMAP projections keyed by embedding name, filled lazily by `update`
# so each embedding type is only projected once per kernel session.
cached_modalities = {}

def update(embedding_selector, color_selector):
    """Redraw the UMAP scatter plot for the selected embedding and label.

    Shows a loading spinner in `output_plot`, computes the UMAP projection of
    the chosen embedding matrix (or reuses the one stored in
    `cached_modalities`), writes its first two components into the `UMAP-1`
    and `UMAP-2` columns of `df`, and draws a scatter plot colored by the
    selected label column.

    Args:
        embedding_selector: Name of the embedding to project, either
            "img_embed" or "img_proj".
        color_selector: Name of the `df` column used as the point hue.

    Raises:
        ValueError: If `embedding_selector` is not a known embedding name
            (raised inside the `output_plot` context).
    """
    with output_plot:
        # Swap whatever is currently shown for the spinner right away, then
        # arm a deferred clear so the spinner is replaced by the next output.
        output_plot.clear_output(wait=False)
        display(HTML(spinner_html))
        output_plot.clear_output(wait=True)

        # NOTE(review): this projects the raw matrices; the standardized
        # copies in `modality_map` are not used here -- confirm which input
        # UMAP is meant to receive.
        raw_embeddings = {"img_embed": X_emb, "img_proj": X_proj}
        if embedding_selector not in raw_embeddings:
            raise ValueError(f"Unknown embedding: {embedding_selector}")
        X = raw_embeddings[embedding_selector]

        # Project each embedding type only once and reuse the result on
        # later selections of the same type.
        if embedding_selector not in cached_modalities:
            cached_modalities[embedding_selector] = UMAP(n_neighbors=30, n_jobs=4).fit_transform(X)
        reduced = cached_modalities[embedding_selector]

        # The plot reads its coordinates from `df`, so they are stored as columns.
        df["UMAP-1"] = reduced[:, 0]
        df["UMAP-2"] = reduced[:, 1]

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 7))
        try:
            sns.scatterplot(df, x="UMAP-1", y="UMAP-2", hue=color_selector, s=10, ax=ax)
            # Place the legend outside the axes, to the right of the plot.
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), markerscale=2)
            plt.show()
        finally:
            # A new figure is created on every widget change; release it
            # explicitly so figures do not pile up when the backend does not
            # close them itself or when plotting fails part-way.
            plt.close(fig)

# Bind the control widgets to `update`, matching dict keys to its parameters.
interactive = interactive_output(update, controls)

# Layout: a narrow column with the radio buttons on the left and the plot
# output area on the right.
checkbox_column = VBox(
    [*controls.values()],
    layout=Layout(width="150px"),
)
ui = HBox(
    [
        checkbox_column,
        VBox([output_plot], layout=Layout(width="1000px")),
    ]
)

display(ui, interactive)