In [None]:
# Shell escape: run the `src.train` module with the lung_s2 config file.
# `-m src.train` resolves `src` relative to the kernel's working directory
# (presumably the repository root — TODO confirm).
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_s2.yaml

In [None]:
# Shell escape: run the `src.train` module with the lung_e2h2s2 config file
# (the same config file that the analysis cells below load).
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2h2s2.yaml

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import src.mvae.mt.mvae.utils as utils
import torch
import yaml

from scipy.io import mmread
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from src.mvae.mt.mvae.models.gene_vae import GeneVAE
from src.mvae.mt.mvae.ops.hyperbolics import lorentz_to_poincare
from src.mvae.mt.mvae.ops.spherical import spherical_to_projected
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Read the experiment configuration (YAML) into a plain dict.
config_path = "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2h2s2.yaml"
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

# The dataset is built entirely from the keyword options stored in the config.
data_options = config["data"]["options"]
dataset = GeneDataset(**data_options)

# Quick summary, one value per line: n_gene, n_batch, number of items.
print(dataset.n_gene, dataset.n_batch, len(dataset), sep="\n")

# Shuffled loader over the full dataset.
dataloader = DataLoader(dataset, shuffle=True, batch_size=2048, num_workers=16)

# Peek at one randomly chosen item and its largest value.
sample_index = np.random.choice(len(dataset))
x, batch_idx = dataset[sample_index]
print(x, batch_idx)
print(x.max())

In [None]:
# Optional checkpoint to restore. With `checkpoint_path = None` no state dict is
# loaded, so every later cell uses whatever weights `GeneModule(config)` constructs.
# checkpoint_path = "/home/romainlhardy/code/hyperbolic-cancer/models/mvae/lung_mvae_epoch=499.ckpt"
checkpoint_path = None

# Use the GPU when one is visible, otherwise fall back to the CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The model options need the dataset-dependent sizes before the module is built.
config["lightning"]["model"]["options"]["n_gene"] = dataset.n_gene
config["lightning"]["model"]["options"]["n_batch"] = dataset.n_batch
module = GeneModule(config).to(device)

if checkpoint_path is not None:
    # map_location remaps stored tensors onto `device`, so a checkpoint saved on a
    # GPU can still be restored on a CPU-only machine (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    module.load_state_dict(checkpoint["state_dict"])

model = module.model
model.eval()

# Sanity check on a single batch. Nothing here is back-propagated, so run it
# without autograd tracking to avoid building a graph.
x, batch_idx = next(iter(dataloader))
with torch.no_grad():
    outputs = model(x.to(device), batch_idx.to(device))

    # First latent component only: its posterior (q_z) and prior (p_z) objects.
    r = outputs["reparametrized"][0]
    q_z = r.q_z
    p_z = r.p_z
    p_samples = p_z.rsample(torch.Size([1000]))
    q_samples = q_z.rsample(torch.Size([1000]))

print(p_samples.shape, q_samples.shape)
print(q_z.loc, q_z.scale)

In [None]:
def sphere_proj_2d(embeddings: np.ndarray) -> np.ndarray:
    """Map 3-D coordinates to 2-D by rescaling x and y with a z-dependent factor.

    The last axis of ``embeddings`` is read at indices 0, 1 and 2 as ``(x, y, z)``.
    Each point becomes ``(s * x, s * y)`` with ``s = sqrt(2 / (1 - z)) / 2``.
    NOTE(review): this has the form of a Lambert azimuthal equal-area projection
    rescaled by 1/2 — confirm that this is the intended map.

    Args:
        embeddings: Array whose last axis has at least three entries; any leading
            axes are preserved.

    Returns:
        Array with the same leading axes as ``embeddings`` and a last axis of size 2.

    Note:
        ``z == 1`` makes the denominator ``1 - z`` zero, so such points come out
        non-finite (NumPy emits a divide-by-zero warning rather than raising).
    """
    x, y, z = (embeddings[..., axis] for axis in range(3))
    # Multiplying by 0.5 is exactly the same as dividing by 2 in floating point.
    scale = 0.5 * np.sqrt(2 / (1 - z))
    return np.stack((scale * x, scale * y), axis=-1)

# Unshuffled pass over the dataset, so row i of every latent array below lines up
# with dataset item i (the plotting cell indexes per-cell labels by position).
dataloader = DataLoader(dataset, shuffle=False, batch_size=2048, num_workers=16)

# Keep the per-batch "reparametrized" outputs; latent means are extracted afterwards.
reparametrized = []
for x, batch_idx in tqdm(dataloader):
    with torch.no_grad():
        outputs = model(x.to(device), batch_idx.to(device))
    reparametrized.append(outputs["reparametrized"])


def _stack_component_locs(component: int) -> np.ndarray:
    """Concatenate ``q_z.loc`` of one latent component over all batches as a NumPy array."""
    per_batch = [batch_out[component].q_z.loc for batch_out in reparametrized]
    return torch.cat(per_batch, dim=0).detach().cpu().numpy()


# Components 0, 1, 2 are read as the Euclidean, hyperbolic and spherical parts
# (presumably following the e2h2s2 config order — TODO confirm against the model).
e_latents = _stack_component_locs(0)
h_latents = _stack_component_locs(1)
s_latents = _stack_component_locs(2)
# Alternative kept from the original: take the spherical latents from component 0
# (presumably for a config with a single spherical component).
# s_latents = _stack_component_locs(0)

# 2-D views used for plotting. The second argument to lorentz_to_poincare is fixed
# at 1.0 here (presumably the curvature/radius — verify against its definition).
h_tensor = torch.from_numpy(h_latents)
hp_latents = lorentz_to_poincare(h_tensor, torch.tensor(1.0)).detach().cpu().numpy()
sp_latents = sphere_proj_2d(s_latents)

In [None]:
# Per-cell labels used only for colouring. Set `cell_type_path = None` to colour
# every point identically.
cell_type_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
# cell_type_path = None
if cell_type_path is not None:
    # Missing labels are grouped under a single "Unknown" category.
    cell_types = pd.read_csv(cell_type_path, sep="\t")["cell_type"].replace(np.nan, "Unknown").values
    # cell_types = pd.read_csv(cell_type_path, sep="\t", header=None).iloc[:, 0].values
else:
    cell_types = np.ones((len(s_latents),)) # Dummy cell types

# Labels are matched to latents purely by row position, so the counts must agree.
# Failing here gives a clearer message than the boolean-index error raised later.
if len(cell_types) != len(s_latents):
    raise ValueError(
        f"Got {len(cell_types)} cell-type labels but {len(s_latents)} latent rows; "
        "the metadata must have exactly one row per dataset item, in the same order."
    )

# One rainbow colour per distinct label.
unique_cell_types = np.unique(cell_types)
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_cell_types)))
color_map = dict(zip(unique_cell_types, colors))
point_colors = [color_map[cell_type] for cell_type in cell_types]

# Boolean mask over cells; all True means "plot everything". The name shadows the
# builtin `filter`, but it is kept because the 3D visualization cell reads it.
filter = np.ones((len(s_latents),), dtype=bool) # Dummy filter

# Build the colour array for the selected cells once instead of once per panel.
filtered_colors = np.array(point_colors)[filter]

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (e_latents, "Euclidean Space"),
    (hp_latents, "Hyperbolic Space (Poincaré Disk)"),
    (sp_latents, "Spherical Space"),
)
for ax, (latents, panel_title) in zip(axs, panels):
    ax.scatter(latents[filter, 0], latents[filter, 1], s=0.5, alpha=0.6, c=filtered_colors)
    ax.set_title(panel_title)

plt.tight_layout()

# Create the output directory if needed so savefig does not fail on a fresh checkout.
figure_path = "/home/romainlhardy/code/hyperbolic-cancer/figures/lung_mvae.png"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, dpi=200)
plt.show()

In [7]:
import numpy as np
import plotly.graph_objects as go
import os
import pandas as pd

from plotly.subplots import make_subplots


def create_sphere_surface(radius=0.99, resolution=100):
    """Return x, y, z coordinate grids sampling a sphere centred at the origin.

    Args:
        radius: Sphere radius (defaults to 0.99).
        resolution: Number of samples along each of the two angular directions.

    Returns:
        Tuple ``(x, y, z)`` of arrays, each of shape ``(resolution, resolution)``.
        Rows sweep the elevation angle from -pi/2 to pi/2; columns sweep the
        azimuth from 0 to 2*pi.
    """
    azimuth = np.linspace(0, 2*np.pi, resolution)
    elevation = np.linspace(-np.pi/2, np.pi/2, resolution)
    azimuth_grid, elevation_grid = np.meshgrid(azimuth, elevation)

    # Radius of the horizontal circle at each elevation.
    ring = radius * np.cos(elevation_grid)

    return (
        ring * np.cos(azimuth_grid),
        ring * np.sin(azimuth_grid),
        radius * np.sin(elevation_grid),
    )


def load_cell_types(file_path):
    """Read a header-less CSV and return the values of its first column.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts (path or file-like object).

    Returns:
        The first column's ``.values``, one entry per line of the file.
    """
    table = pd.read_csv(file_path, header=None)
    first_column = table.iloc[:, 0]
    return first_column.values


def create_3d_visualization(embeddings, cell_types, color_map, save_path,
                            title="scPhere 3D Visualization", showlegend=False):
    """Render a 3-D scatter of embeddings around a sphere and save it as HTML.

    Args:
        embeddings: Array indexed as ``embeddings[:, 0..2]`` for the x, y, z
            coordinates of each point.
        cell_types: Per-point labels; each one must be a key of ``color_map``.
        color_map: Mapping from label to the colour used for that label.
        save_path: Destination of the HTML file. Its parent directory is created
            if it does not exist yet.
        title: Figure title. Defaults to the previously hard-coded title.
        showlegend: Whether the legend is displayed. Defaults to ``False``, the
            previously hard-coded value; the per-label legend entries are added
            to the figure either way.

    Returns:
        None. The figure is written to ``save_path`` via ``fig.write_html``.

    Raises:
        KeyError: If a label in ``cell_types`` is missing from ``color_map``.
    """
    point_colors = [color_map[cell_type] for cell_type in cell_types]

    fig = go.Figure()

    # The data points themselves, one marker per row of `embeddings`.
    fig.add_trace(go.Scatter3d(
        x=embeddings[:, 0],
        y=embeddings[:, 1],
        z=embeddings[:, 2],
        mode="markers",
        marker=dict(
            size=3,
            color=point_colors,
            opacity=0.6
        ),
        name="Cells"
    ))

    # Fully opaque, uniformly coloured sphere at the default radius (0.99);
    # presumably drawn just inside the points to hide those on the far side.
    x, y, z = create_sphere_surface()
    fig.add_trace(go.Surface(
        x=x, y=y, z=z,
        opacity=1.0,
        showscale=False,
        colorscale=[[0, "#dbd7d2"], [1, "#dbd7d2"]],
        name="Sphere"
    ))

    # Empty traces that exist only to contribute one legend entry per label.
    for cell_type, color in color_map.items():
        fig.add_trace(go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode="markers",
            marker=dict(size=8, color=color),
            name=cell_type
        ))

    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False, showline=False)
    fig.update_layout(
        title=title,
        width=1000,
        height=1000,
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            aspectmode="data"
        ),
        showlegend=showlegend,
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="center",
            x=0.5
        )
    )

    # Make sure the destination directory exists; write_html does not create it.
    output_dir = os.path.dirname(save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    fig.write_html(save_path)

# Export the spherical latents as an interactive 3-D HTML figure, restricted to the
# cells selected by the boolean `filter` mask and coloured with `color_map`.
save_path = "/home/romainlhardy/code/hyperbolic-cancer/animations/lung_mvae.html"
embeddings = s_latents
create_3d_visualization(
    embeddings=embeddings[filter],
    cell_types=cell_types[filter],
    color_map=color_map,
    save_path=save_path,
)