In [None]:
import h5py
import pandas as pd
import numpy as np

# Label columns of the metadata CSV offered as "Color by" options in the plot.
labels = ["Cardiomegaly", "Edema", "Consolidation", "Atelectasis", "Pleural Effusion"]

# TODO replace this entire cell with getting the embeddings and metadata from the API
# !wget "https://uchicago.box.com/shared/static/er0sz3yt17ke9zyg97ruhgemntuhx4r1.h5" -O "data/openi.h5"
# !wget "https://uchicago.box.com/shared/static/rasdb3b3xuirx7q4k012vnnx558px3we.csv" -O "data/openi.csv"

# One metadata row per study; the HDF5 file stores vectors under
# "<group>/<patient_id>/<study_id>/<image_id>", where the patient id is the
# study id with "study" replaced by "patient".
df = pd.read_csv("with-all-data/data/openi.csv")
embed_rows = []
proj_rows = []
with h5py.File("with-all-data/data/openi.h5", "r") as h5:
    for study_id in df["study_id"]:
        study_path = f"{study_id.replace('study', 'patient')}/{study_id}"
        # We know there's only 1 image per study in this h5 file (doesn't matter for the API),
        # so the first key under the study group is the image to use.
        image_id = next(iter(h5[f"img_embed/{study_path}"].keys()))
        embed_rows.append(h5[f"img_embed/{study_path}/{image_id}"][:])
        proj_rows.append(h5[f"img_proj/{study_path}/{image_id}"][:])

# Stack into arrays whose row i corresponds to df row i.
X_emb = np.asarray(embed_rows)
X_proj = np.asarray(proj_rows)

In [None]:
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Preprocessing template: standardize each feature to zero mean / unit variance.
pipe = Pipeline([("scale_raw", StandardScaler())])

# Fit an independent clone of the template per modality so each embedding is
# standardized using its own statistics.
raw_modalities = {"img_embed": X_emb, "img_proj": X_proj}
modality_map = {
    modality: clone(pipe).fit_transform(features)
    for modality, features in raw_modalities.items()
}

# Modality names, in insertion order, used as the embedding-selector options.
inputs = list(modality_map)

In [None]:
from umap import UMAP
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import seaborn as sns
from ipywidgets import HBox, VBox, RadioButtons, Output, Layout, HTML, interactive_output

%matplotlib widget

# Widgets that drive the plot. The keys match the parameter names of `update`,
# because `interactive_output` passes each widget's value by that name.
controls = dict(
    embedding_selector=RadioButtons(
        options=inputs,
        value="img_embed",
        description="Embedding Type:",
    ),
    color_selector=RadioButtons(
        options=labels,
        value="Atelectasis",
        description="Color by:",
    ),
    n_dims=RadioButtons(
        options=[2, 3],
        value=3,
        description="Number of Dimensions",
    ),
)

# Output area that `update` renders the spinner and the figure into.
output_plot = Output()

# HTML/CSS for a centered rotating loading spinner. `update` displays it in
# `output_plot` before (re)computing the plot, so the user sees feedback while
# UMAP runs.
spinner_html = """
<div style="display: flex; justify-content: center; align-items: center; width: 100%; height: 500px;">
    <div class="loader"></div>
</div>

<style>
.loader {
  border: 6px solid #f3f3f3;
  border-top: 6px solid #3498db;
  border-radius: 50%;
  width: 30px;
  height: 30px;
  animation: spin 1s linear infinite;
}

@keyframes spin {
  0% { transform: rotate(0deg); }
  100% { transform: rotate(360deg); }
}
</style>
"""

# UMAP projections already computed by `update`, keyed by "<embedding>_<n>d"
# (e.g. "img_embed_3d"). Revisiting a setting then reuses the result instead
# of re-running UMAP.
cached_modalities = dict()

def update(embedding_selector, color_selector, n_dims):
    """Redraw the UMAP scatter plot for the current widget selections.

    ``interactive_output`` calls this with the values of ``controls`` as keyword
    arguments whenever a widget changes. Everything is rendered into
    ``output_plot``.

    Args:
        embedding_selector: "img_embed" or "img_proj"; selects ``X_emb`` or
            ``X_proj`` as the UMAP input.
        color_selector: name of the ``df`` column whose values color the points.
        n_dims: number of UMAP components to compute and plot (2 or 3).

    Raises:
        ValueError: if ``embedding_selector`` is not a known embedding name.
    """
    # Points whose label is missing (NaN) are drawn in this neutral gray and get
    # their own legend entry, instead of silently borrowing a category's color.
    missing_color = (0.8, 0.8, 0.8)
    # Abbreviations upper-cased in the legend title (whole words only).
    acronyms = {"Fu": "FU", "Id": "ID"}

    with output_plot:
        # Show a spinner right away. The wait=True clear is deferred until the
        # next output (the figure), so the spinner stays up while UMAP runs.
        output_plot.clear_output(wait=False)
        display(HTML(spinner_html))
        output_plot.clear_output(wait=True)
        plt.close("all")

        if embedding_selector == "img_embed":
            X = X_emb
        elif embedding_selector == "img_proj":
            X = X_proj
        else:
            raise ValueError(f"Unknown embedding: {embedding_selector}")

        # UMAP is expensive: compute each (embedding, n_dims) projection only once.
        cache_key = f"{embedding_selector}_{n_dims}d"
        if cache_key not in cached_modalities:
            cached_modalities[cache_key] = UMAP(n_components=n_dims, n_neighbors=30, n_jobs=4).fit_transform(X)
        reduced = cached_modalities[cache_key]

        label_values = pd.Categorical(df[color_selector])
        codes = label_values.codes  # pandas uses -1 for missing values
        categories = list(label_values.categories)
        # Count categories directly; `codes.max() + 1` is 0 when every value is missing.
        n_unique = len(categories)
        long_legend = n_unique > 10
        cmap = sns.color_palette("husl", n_unique) if long_legend else sns.color_palette("tab10")
        # Guard the -1 code: `cmap[-1]` would paint missing values with the
        # palette's last color, which in the husl case is a real category's color.
        colors = np.asarray([cmap[code] if code >= 0 else missing_color for code in codes])

        legend_elements = [
            Line2D([0], [0], marker="o", ls="none", color=cmap[i], label=category)
            for i, category in enumerate(categories)
        ]
        if (codes < 0).any():
            legend_elements.append(
                Line2D([0], [0], marker="o", ls="none", color=missing_color, label="Missing")
            )

        fig = plt.figure(figsize=(9, 7), num=" ")
        # Columns 0-6 hold the scatter plot; columns 7-8 hold the legend.
        gs = GridSpec(1, 9, figure=fig)
        ax = fig.add_subplot(gs[0, 0:7], projection=None if n_dims == 2 else "3d")
        ax.set_xlabel("UMAP-1")
        ax.set_ylabel("UMAP-2")

        args = [reduced[:, 0], reduced[:, 1]]
        if n_dims > 2:
            args.append(reduced[:, 2])
            ax.set_zlabel("UMAP-3")

        ax.scatter(*args, c=colors, s=1)

        ax2 = fig.add_subplot(gs[0, 7:9])
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax2.axis("off")
        # "pleural_effusion" -> "Pleural Effusion". Abbreviations are fixed per word,
        # so a word such as "Fundus" is not turned into "FUndus".
        title_words = " ".join(color_selector.split("_")).title().split(" ")
        title = " ".join(acronyms.get(word, word) for word in title_words)
        ax2.legend(
            handles=legend_elements,
            loc="center",
            frameon=True,
            title=title,
            fontsize=8 if long_legend else 10,
        )

        fig.tight_layout()
        plt.show()

# Wire the widgets to `update`: any control change re-invokes it with the
# current values, passed as keyword arguments named after the `controls` keys.
interactive = interactive_output(update, controls)

# Layout: a narrow column of controls next to a wide plot area.
checkbox_column = VBox(list(controls.values()), layout=Layout(width="150px"))
plot_column = VBox([output_plot], layout=Layout(width="1000px"))
ui = HBox([checkbox_column, plot_column])
display(ui, interactive)