In [None]:
# Launch the `src.train` module in a subprocess with the cellxgene_e6 configuration file.
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/cellxgene/cellxgene_e6.yaml

In [None]:
# Launch the `src.train` module in a subprocess with the cellxgene_e5h5s5 configuration file.
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/cellxgene/cellxgene_e5h5s5.yaml

In [None]:
# Launch the `src.train` module in a subprocess with the cellxgene_s15 configuration file.
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/cellxgene/cellxgene_s15.yaml

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
import src.mvae.mt.mvae.utils as utils
import torch
import yaml

from scipy.io import mmread
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from src.mvae.mt.mvae.distributions import *
from src.mvae.mt.mvae.models.gene_vae import GeneVAE
from src.mvae.mt.mvae.ops.hyperbolics import lorentz_to_poincare
from src.mvae.mt.mvae.ops.spherical import spherical_to_projected
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Read the experiment configuration (YAML) that drives the dataset and the model.
config_path = "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2h2s2.yaml"
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

# Build the dataset from the keyword options stored under data -> options.
dataset_options = config["data"]["options"]
dataset = GeneDataset(**dataset_options)
for summary_value in (dataset.n_gene, dataset.n_batch, len(dataset)):
    print(summary_value)

dataloader = DataLoader(dataset, batch_size=256, num_workers=16, shuffle=True)

# Inspect one randomly chosen example as a quick sanity check.
sample_index = np.random.choice(len(dataset))
x, batch_idx = dataset[sample_index]
print(x, batch_idx)
print(x.max())

In [None]:
# Path of the trained weights; set to None to keep the freshly initialised model.
checkpoint_path = "/home/romainlhardy/code/hyperbolic-cancer/models/mvae/lung_mvae_e2h2s2.ckpt"
# checkpoint_path = None

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The model dimensions depend on the dataset, so they are injected into the config
# before the Lightning module is constructed.
config["lightning"]["model"]["options"]["n_gene"] = dataset.n_gene
config["lightning"]["model"]["options"]["n_batch"] = dataset.n_batch
module = GeneModule(config).to(device)

if checkpoint_path is not None:
    # map_location remaps tensors that were saved from another device onto `device`.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    module.load_state_dict(checkpoint["state_dict"])

model = module.model
model.eval()

# Smoke-test a single forward pass; no_grad avoids building an autograd graph
# that is never used in this notebook.
x, batch_idx = next(iter(dataloader))
with torch.no_grad():
    outputs = model(x.to(device), batch_idx.to(device))

In [None]:
def sphere_proj_2d(embeddings: np.ndarray) -> np.ndarray:
    """Flatten 3-D points to 2-D by rescaling x and y with a z-dependent factor.

    Args:
        embeddings: Array whose last axis holds at least the (x, y, z) coordinates.

    Returns:
        Array with the same leading shape and a last axis of size 2, holding
        ``x * s`` and ``y * s`` where ``s = sqrt(2 / (1 - z)) / 2``.
        The factor is singular at ``z == 1``.
    """
    height = embeddings[..., 2]
    scale = np.sqrt(2 / (1 - height)) / 2
    planar = (scale * embeddings[..., 0], scale * embeddings[..., 1])
    return np.stack(planar, axis=-1)

def get_latents(reparametrized, num_components=1):
    """Collect per-component posterior means and map them to plottable coordinates.

    Args:
        reparametrized: Non-empty sequence with one entry per batch. Each entry is a
            sequence of per-component objects exposing the posterior as ``q_z``,
            whose mean is read from ``q_z.loc``.
        num_components: Number of latent components per batch entry.

    Returns:
        List of ``num_components`` NumPy arrays with all batches concatenated along
        axis 0. ``EuclideanNormal`` components are returned unchanged,
        ``WrappedNormal`` components are passed through ``lorentz_to_poincare`` and
        ``RadiusVonMisesFisher`` components through ``sphere_proj_2d``.

    Raises:
        ValueError: If a component's posterior is none of the supported types.
    """
    assert len(reparametrized) > 0

    # Gather the posterior means batch by batch, grouped by component.
    per_component = [[] for _ in range(num_components)]
    for batch_components in reparametrized:
        for i, component in enumerate(batch_components):
            per_component[i].append(component.q_z.loc.detach().cpu().numpy())
    latents = [np.concatenate(chunks, axis=0) for chunks in per_component]

    # The posterior type of the first batch decides how each component is projected.
    for i in range(num_components):
        posterior = reparametrized[0][i].q_z
        if isinstance(posterior, EuclideanNormal):
            continue
        elif isinstance(posterior, WrappedNormal):
            # NOTE(review): the second argument is hard-coded to 1.0 — confirm it
            # matches the value the model was trained with.
            latents[i] = lorentz_to_poincare(torch.from_numpy(latents[i]), torch.tensor(1.0)).detach().cpu().numpy()
        elif isinstance(posterior, RadiusVonMisesFisher):
            # Already a NumPy array, so no round trip through torch is needed.
            latents[i] = sphere_proj_2d(latents[i])
        else:
            raise ValueError(
                f"Unsupported posterior type for component {i}: {type(posterior).__name__}"
            )
    return latents

# Unshuffled pass over the whole dataset: later cells index the latents with masks
# built from the metadata file, so the row order must be stable.
dataloader = DataLoader(dataset, batch_size=2048, num_workers=16, shuffle=False)

# Keep only the "reparametrized" entry of each batch output; gradients are not needed.
reparametrized = []
with torch.no_grad():
    for x, batch_idx in tqdm(dataloader):
        outputs = model(x.to(device), batch_idx.to(device))
        reparametrized.append(outputs["reparametrized"])

# One latent array per model component.
num_components = len(model.components)
latents = get_latents(reparametrized, num_components)

In [None]:
# Cell types
metadata_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
key = "cell_type"
if metadata_path is not None:
    # Missing annotations become the explicit label "Unknown".
    categories = pd.read_csv(metadata_path, sep="\t")[key].fillna("Unknown").values
else:
    # No metadata file: label every cell "Unknown" so the rest of the cell still runs.
    categories = np.array(["Unknown"] * len(dataset))

# The boolean mask below indexes the latent arrays row by row, so the metadata
# must provide exactly one label per embedded cell.
assert len(categories) == len(latents[0]), (
    f"metadata has {len(categories)} rows but {len(latents[0])} cells were embedded"
)

# NOTE: `filter` shadows the Python builtin; the name is kept for consistency
# with the other cells of this notebook.
filter = categories != "Unknown"

unique_categories_all = np.unique(categories)
unique_categories_filtered = np.unique(categories[filter])

# One rainbow color per category (including "Unknown"), looked up per cell.
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_categories_all)))
color_map = dict(zip(unique_categories_all, colors))
point_colors = np.array([color_map[category] for category in categories])

# One panel per latent component, showing the first two coordinates.
# squeeze=False always returns a 2-D axes array, even for a single component.
fig, axs = plt.subplots(1, num_components, figsize=(num_components * 4, 4), squeeze=False)
for i, ax in enumerate(axs[0]):
    ax.scatter(latents[i][filter, 0], latents[i][filter, 1], s=1.0, alpha=0.6, c=point_colors[filter])
    ax.set_title("")
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
main_plot_path = f"/home/romainlhardy/code/hyperbolic-cancer/figures/{config['experiment']}_{key}.png"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(main_plot_path), exist_ok=True)
fig.savefig(main_plot_path, dpi=200, bbox_inches="tight")
plt.show()

# Legend
# Drawn as its own figure so it can be saved separately from the scatter panels.
legend_fig, legend_ax = plt.subplots(figsize=(3, max(2, len(unique_categories_filtered) * 0.3)))

# Empty scatters act as legend handles carrying each category's color and label.
legend_handles = [
    legend_ax.scatter([], [], color=color_map[category], label=category, s=50)
    for category in unique_categories_filtered
]

legend_ax.legend(handles=legend_handles, title=key.replace("_", " ").title(), loc="center", frameon=False)

legend_ax.axis("off")

legend_plot_path = f"/home/romainlhardy/code/hyperbolic-cancer/figures/{config['experiment']}_{key}_legend.png"
legend_fig.tight_layout()
os.makedirs(os.path.dirname(legend_plot_path), exist_ok=True)
legend_fig.savefig(legend_plot_path, dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Tumor stage
metadata_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
key = "mp_assignment"
legend_title = "Metastatic Progression"
top_n = 10  # number of most frequent categories considered
remove_first_n = 1  # how many of the most frequent ones are dropped again (presumably the "Unknown" placeholder — verify)

if metadata_path is not None:
    categories = pd.read_csv(metadata_path, sep="\t")[key].fillna("Unknown")
    top_categories = categories.value_counts().head(top_n).index
    print(top_categories)
    # Keep only the retained top categories; everything else is relabelled "Other".
    kept_categories = top_categories[remove_first_n:]
    categories = categories.where(categories.isin(kept_categories), "Other").values
else:
    # No metadata file: label every cell "Other" so the rest of the cell still runs.
    categories = np.array(["Other"] * len(dataset))

# The boolean mask below indexes the latent arrays row by row, so the metadata
# must provide exactly one label per embedded cell.
assert len(categories) == len(latents[0]), (
    f"metadata has {len(categories)} rows but {len(latents[0])} cells were embedded"
)

# NOTE: `filter` shadows the Python builtin; the name is kept because the 3-D
# visualization cells further down read `filter`, `categories`, `color_map` and `key`.
filter = categories != "Other"

unique_categories_all = np.unique(categories)
unique_categories_filtered = np.unique(categories[filter])

# One rainbow color per category (including "Other"), looked up per cell.
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_categories_all)))
color_map = dict(zip(unique_categories_all, colors))
point_colors = np.array([color_map[category] for category in categories])

# One panel per latent component, showing the first two coordinates.
# squeeze=False always returns a 2-D axes array, even for a single component.
fig, axs = plt.subplots(1, num_components, figsize=(num_components * 4, 4), squeeze=False)
for i, ax in enumerate(axs[0]):
    ax.scatter(latents[i][filter, 0], latents[i][filter, 1], s=1.0, alpha=0.6, c=point_colors[filter])
    ax.set_title("")
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
main_plot_path = f"/home/romainlhardy/code/hyperbolic-cancer/figures/{config['experiment']}_{key}.png"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(main_plot_path), exist_ok=True)
fig.savefig(main_plot_path, dpi=200, bbox_inches="tight")
plt.show()

# Legend
# Drawn as its own figure so it can be saved separately from the scatter panels.
legend_fig, legend_ax = plt.subplots(figsize=(3, max(2, len(unique_categories_filtered) * 0.3)))

# Empty scatters act as legend handles carrying each category's color and label.
legend_handles = [
    legend_ax.scatter([], [], color=color_map[category], label=category, s=50)
    for category in unique_categories_filtered
]

legend_ax.legend(handles=legend_handles, title=legend_title, loc="center", frameon=False)

legend_ax.axis("off")

legend_plot_path = f"/home/romainlhardy/code/hyperbolic-cancer/figures/{config['experiment']}_{key}_legend.png"
legend_fig.tight_layout()
os.makedirs(os.path.dirname(legend_plot_path), exist_ok=True)
legend_fig.savefig(legend_plot_path, dpi=200, bbox_inches="tight")
plt.show()

In [9]:
# 3D sphere visualization
def create_sphere_surface(radius=0.99, resolution=50):
    """Sample a sphere of the given radius on a regular longitude/latitude grid.

    Args:
        radius: Radius of the sphere.
        resolution: Number of samples along each of the two angular axes.

    Returns:
        Tuple ``(x, y, z)`` of ``(resolution, resolution)`` arrays holding the
        Cartesian coordinates of the grid points.
    """
    longitude = np.linspace(0, 2 * np.pi, resolution)
    latitude = np.linspace(-np.pi / 2, np.pi / 2, resolution)
    lon_grid, lat_grid = np.meshgrid(longitude, latitude)

    # Radius of the horizontal circle at each latitude.
    ring_radius = radius * np.cos(lat_grid)
    return (
        ring_radius * np.cos(lon_grid),
        ring_radius * np.sin(lon_grid),
        radius * np.sin(lat_grid),
    )


def load_cell_types(file_path):
    """Return the first column of a header-less CSV file as an array.

    Args:
        file_path: Path of the CSV file; its first row is treated as data.

    Returns:
        The values of the first column, one entry per row of the file.
    """
    table = pd.read_csv(file_path, header=None)
    first_column = table.iloc[:, 0]
    return first_column.values


def create_3d_sphere_visualization(embeddings, cell_types, color_map, save_path):
    """Write an interactive Plotly HTML page showing embeddings around a sphere.

    Args:
        embeddings: 2-D array indexed as ``embeddings[:, k]``; columns 0, 1 and 2
            are used as the x, y and z coordinates.
        cell_types: One label per row of ``embeddings``; used as hover text and to
            look up each point's color.
        color_map: Mapping from label to marker color. The values are handed to
            Plotly unchanged (NOTE(review): confirm they are in a format Plotly
            accepts).
        save_path: Path of the HTML file to write.
    """
    # Axis styling shared by all three axes: no ticks, grid or axis lines.
    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False, showline=False)

    cells_trace = go.Scatter3d(
        x=embeddings[:, 0],
        y=embeddings[:, 1],
        z=embeddings[:, 2],
        mode="markers",
        marker=dict(
            size=3,
            color=[color_map[label] for label in cell_types],
            opacity=0.6,
        ),
        text=cell_types,
        hovertemplate="%{text}<extra></extra>",
        name="Cells",
    )

    # Opaque, uniformly colored reference sphere (default radius 0.99).
    sphere_x, sphere_y, sphere_z = create_sphere_surface()
    sphere_trace = go.Surface(
        x=sphere_x, y=sphere_y, z=sphere_z,
        opacity=1.0,
        showscale=False,
        colorscale=[[0, "#dbd7d2"], [1, "#dbd7d2"]],
        name="Sphere",
    )

    fig = go.Figure()
    fig.add_trace(cells_trace)
    fig.add_trace(sphere_trace)

    # One empty placeholder trace per label, named after it. The layout below
    # sets showlegend=False.
    for label, color in color_map.items():
        placeholder = go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode="markers",
            marker=dict(size=8, color=color),
            name=label,
        )
        fig.add_trace(placeholder)

    fig.update_layout(
        title="Sphere 3D Visualization",
        width=1000,
        height=1000,
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            aspectmode="data",
        ),
        showlegend=False,
    )

    fig.write_html(save_path)

# `latents[-1]` has already been reduced to two columns by get_latents(), but the
# 3-D viewer reads columns 0, 1 and 2. Rebuild the unprojected posterior means of
# the last component directly from the stored model outputs instead.
embeddings = np.concatenate(
    [batch_components[-1].q_z.loc.detach().cpu().numpy() for batch_components in reparametrized],
    axis=0,
)
assert embeddings.shape[1] == 3, (
    f"expected 3-D coordinates for the sphere view, got {embeddings.shape[1]} columns"
)

save_path = f"/home/romainlhardy/code/hyperbolic-cancer/animations/{config['experiment']}_{key}_sphere.html"
# write_html does not create missing directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
create_3d_sphere_visualization(embeddings[filter], categories[filter], color_map, save_path)

In [10]:
# 3D hyperboloid visualization
def create_hyperboloid_surface(radius=2, height_limit=5, num_points=50):
    """Sample the surface ``z = sqrt(1 + x^2 + y^2)`` on a square grid.

    Args:
        radius: Half-width of the square ``[-radius, radius]^2`` covered by the grid.
        height_limit: Heights above this value are replaced by NaN.
        num_points: Number of samples along each axis.

    Returns:
        Tuple ``(x_grid, y_grid, z_grid)`` of ``(num_points, num_points)`` arrays.
    """
    ticks = np.linspace(-radius, radius, num_points)
    x_grid, y_grid = np.meshgrid(ticks, ticks)

    heights = np.sqrt(1 + x_grid ** 2 + y_grid ** 2)
    # Mask everything that rises above the requested height.
    z_grid = np.where(heights > height_limit, np.nan, heights)

    return x_grid, y_grid, z_grid


def create_3d_hyperboloid_visualization(embeddings, cell_types, color_map, save_path):
    """Write an interactive Plotly HTML page showing embeddings on a hyperboloid sheet.

    Args:
        embeddings: 2-D array indexed as ``embeddings[:, k]``; column 0 is drawn on
            the vertical (z) axis and columns 1 and 2 on the x and y axes.
        cell_types: One label per row of ``embeddings``; used as hover text and to
            look up each point's color.
        color_map: Mapping from label to marker color. The values are handed to
            Plotly unchanged (NOTE(review): confirm they are in a format Plotly
            accepts).
        save_path: Path of the HTML file to write.
    """
    # Axis styling shared by all three axes: no ticks, grid or axis lines.
    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False, showline=False)

    # Plot the cell embeddings
    cells_trace = go.Scatter3d(
        x=embeddings[:, 1],
        y=embeddings[:, 2],
        z=embeddings[:, 0],
        mode="markers",
        marker=dict(
            size=3,
            color=[color_map[label] for label in cell_types],
            opacity=0.6,
        ),
        text=cell_types,
        hovertemplate="%{text}<extra></extra>",
        name="Cells",
    )

    # Semi-transparent, uniformly colored reference surface.
    sheet_x, sheet_y, sheet_z = create_hyperboloid_surface(radius=3, height_limit=4)
    sheet_trace = go.Surface(
        x=sheet_x, y=sheet_y, z=sheet_z,
        opacity=0.3,
        showscale=False,
        colorscale=[[0, "#dbd7d2"], [1, "#dbd7d2"]],
        name="Hyperboloid",
    )

    fig = go.Figure()
    fig.add_trace(cells_trace)
    fig.add_trace(sheet_trace)

    # One empty placeholder trace per label, named after it. The layout below
    # sets showlegend=False.
    for label, color in color_map.items():
        placeholder = go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode="markers",
            marker=dict(size=8, color=color),
            name=label,
        )
        fig.add_trace(placeholder)

    fig.update_layout(
        title="Hyperboloid 3D Visualization",
        width=1000,
        height=1000,
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            aspectmode="data",
        ),
        showlegend=False,
    )

    fig.write_html(save_path)

# `latents[1]` has already been passed through lorentz_to_poincare() by
# get_latents(), but the 3-D viewer reads columns 0, 1 and 2 of the unprojected
# coordinates. Rebuild the posterior means of component 1 directly from the
# stored model outputs instead.
embeddings = np.concatenate(
    [batch_components[1].q_z.loc.detach().cpu().numpy() for batch_components in reparametrized],
    axis=0,
)
assert embeddings.shape[1] == 3, (
    f"expected 3-D coordinates for the hyperboloid view, got {embeddings.shape[1]} columns"
)

save_path = f"/home/romainlhardy/code/hyperbolic-cancer/animations/{config['experiment']}_{key}_hyperboloid.html"
# write_html does not create missing directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
create_3d_hyperboloid_visualization(embeddings[filter], categories[filter], color_map, save_path)