In [None]:
# Launch training by running the `src.train` module in a shell, using the lung e2h2s2 config file.
!python3 -m src.train \
    --config /home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2h2s2.yaml

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/romainlhardy/miniconda3/envs/mvae/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Seed set to 42
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mromain_hardy[0m ([33mromain_hardy-harvard-university[0m) to [

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import src.mvae.mt.mvae.utils as utils
import torch
import yaml

from scipy.io import mmread
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from src.mvae.mt.mvae.distributions import *
from src.mvae.mt.mvae.models.gene_vae import GeneVAE
from src.mvae.mt.mvae.ops.hyperbolics import lorentz_to_poincare
from src.mvae.mt.mvae.ops.spherical import spherical_to_projected
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
# Read the experiment configuration (YAML) that describes the data and the model.
config_path = "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2.yaml"
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

# Build the dataset from the config's data options and print its basic sizes.
dataset = GeneDataset(**config["data"]["options"])
for size in (dataset.n_gene_r, dataset.n_gene_p, dataset.n_batch, len(dataset)):
    print(size)

# Shuffled loader over the dataset.
dataloader = DataLoader(dataset, shuffle=True, batch_size=2048, num_workers=16)

# Inspect one randomly chosen example and the maximum of its first element.
sample_index = np.random.choice(len(dataset))
x_r, x_p, batch_idx = dataset[sample_index]
print(x_r, x_p, batch_idx)
print(x_r.max())

In [None]:
checkpoint_path = "/home/romainlhardy/code/hyperbolic-cancer/models/mvae/lung_mvae_e2_epoch=999.ckpt"
# checkpoint_path = None  # uncomment to skip loading weights from a checkpoint

device = "cuda"

# The model options need the dataset-dependent sizes before the module is built.
model_options = config["lightning"]["model"]["options"]
model_options["n_gene_r"] = dataset.n_gene_r
model_options["n_gene_p"] = dataset.n_gene_p
model_options["n_batch"] = dataset.n_batch
module = GeneModule(config).to(device)

if checkpoint_path is not None:
    # map_location remaps the stored tensors onto `device`, so the load does not
    # depend on the device the checkpoint happened to be written from.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    module.load_state_dict(checkpoint["state_dict"])

model = module.model
model.eval()

# Sanity-check forward pass on a single batch. Gradients are not needed for
# inspection, so the pass runs under no_grad like the full inference loop does.
x_r, x_p, batch_idx = next(iter(dataloader))
with torch.no_grad():
    outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))

# Look at the first latent component: draw samples from its prior (p_z) and
# posterior (q_z) and print the posterior parameters.
r = outputs["reparametrized"][0]
q_z = r.q_z
p_z = r.p_z
p_samples = p_z.rsample(torch.Size([1000]))
q_samples = q_z.rsample(torch.Size([1000]))
print(p_samples.shape, q_samples.shape)
print(q_z.loc, q_z.scale)

In [None]:
def sphere_proj_2d(embeddings: np.ndarray) -> np.ndarray:
    """Map points given by (x, y, z) on their last axis to 2D coordinates.

    Both x and y are multiplied by ``sqrt(2 / (1 - z)) / 2``. The result keeps
    the leading shape of ``embeddings`` and has a last axis of size 2. The scale
    factor is not finite where ``z == 1`` (division by zero).
    """
    height = embeddings[..., 2]
    scale = np.sqrt(2 / (1 - height)) / 2
    planar = [scale * embeddings[..., axis] for axis in (0, 1)]
    return np.stack(planar, axis=-1)

def get_latents(reparametrized, num_components=1):
    """Gather the posterior locations of every latent component across batches.

    Args:
        reparametrized: Non-empty sequence with one entry per batch. Each entry
            is a sequence with one item per latent component, and every item
            exposes ``q_z.loc`` as a tensor.
        num_components: Number of latent components expected in every batch.

    Returns:
        A list with ``num_components`` NumPy arrays, each the concatenation
        (along axis 0) of that component's ``q_z.loc`` over all batches.
        ``EuclideanNormal`` components are returned as-is, ``WrappedNormal``
        components are passed through ``lorentz_to_poincare`` and
        ``RadiusVonMisesFisher`` components through ``sphere_proj_2d``.

    Raises:
        AssertionError: If ``reparametrized`` is empty.
        ValueError: If a batch does not hold exactly ``num_components`` items,
            or if a component's posterior type is not one of the three above.
    """
    assert len(reparametrized) > 0, "reparametrized must contain at least one batch"

    per_component = [[] for _ in range(num_components)]
    for batch_components in reparametrized:
        # Fail early with a clear message instead of an IndexError or an opaque
        # concatenate error further down.
        if len(batch_components) != num_components:
            raise ValueError(
                f"Expected {num_components} latent components per batch, "
                f"got {len(batch_components)}"
            )
        for chunks, component in zip(per_component, batch_components):
            chunks.append(component.q_z.loc.detach().cpu().numpy())

    latents = [np.concatenate(chunks, axis=0) for chunks in per_component]

    # The posterior type of the first batch decides how each component is mapped.
    for i, component in enumerate(reparametrized[0]):
        q_z = component.q_z
        if isinstance(q_z, EuclideanNormal):
            continue
        elif isinstance(q_z, WrappedNormal):
            # NOTE(review): the second argument is fixed at 1.0 here; confirm it
            # matches the value the model actually uses for this component.
            latents[i] = lorentz_to_poincare(torch.from_numpy(latents[i]), torch.tensor(1.0)).detach().cpu().numpy()
        elif isinstance(q_z, RadiusVonMisesFisher):
            # latents[i] is already a NumPy array, so no tensor round-trip is needed.
            latents[i] = sphere_proj_2d(latents[i])
        else:
            raise ValueError(f"Unsupported posterior distribution type: {type(q_z).__name__}")
    return latents

# shuffle=False keeps the batches in dataset order.
dataloader = DataLoader(dataset, shuffle=False, batch_size=2048, num_workers=16)

# Run the model over every batch without tracking gradients and keep the
# per-batch "reparametrized" outputs.
reparametrized = []
with torch.no_grad():
    for x_r, x_p, batch_idx in tqdm(dataloader):
        outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))
        reparametrized.append(outputs["reparametrized"])

num_components = len(model.components)
latents = get_latents(reparametrized, num_components)

In [None]:
cell_type_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
if cell_type_path is None:
    cell_types = np.ones((len(dataset),)) # Dummy cell types
else:
    # Missing cell-type labels are shown as "Unknown".
    metadata = pd.read_csv(cell_type_path, sep="\t")
    cell_types = metadata["cell_type"].replace(np.nan, "Unknown").values

# Assign one rainbow colour to each distinct cell type.
unique_cell_types = np.unique(cell_types)
colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_cell_types)))
color_map = dict(zip(unique_cell_types, colors))
point_colors = [color_map[label] for label in cell_types]

filter = np.ones((len(dataset),), dtype=bool) # Dummy filter
shown_colors = np.array(point_colors)[filter]

# One scatter panel per latent component, titled with the component's string form.
fig, axs = plt.subplots(1, num_components, figsize=(num_components * 4, 4))
panels = axs if num_components > 1 else [axs]
for i, (ax, component_latents) in enumerate(zip(panels, latents)):
    ax.scatter(component_latents[filter, 0], component_latents[filter, 1], s=0.5, alpha=0.6, c=shown_colors)
    ax.set_title(str(model.components[i]))

plt.tight_layout()
plt.savefig(f"/home/romainlhardy/code/hyperbolic-cancer/figures/{config['experiment']}.png", dpi=200)
plt.show()

In [7]:
import numpy as np
import plotly.graph_objects as go
import os
import pandas as pd

from plotly.subplots import make_subplots


def create_sphere_surface(radius=0.99, resolution=100):
    """Return x, y, z coordinate grids of a sphere surface centred at the origin.

    Each returned array has shape ``(resolution, resolution)``. Rows sweep the
    elevation angle from -pi/2 to pi/2 and columns sweep the azimuth from 0 to
    2*pi, so every (x, y, z) triple lies at distance ``radius`` from the origin.
    """
    azimuth = np.linspace(0, 2*np.pi, resolution)[np.newaxis, :]
    elevation = np.linspace(-np.pi/2, np.pi/2, resolution)[:, np.newaxis]

    # Radius of the horizontal circle at each elevation.
    ring = radius * np.cos(elevation)
    x = ring * np.cos(azimuth)
    y = ring * np.sin(azimuth)
    # z depends on elevation only; repeat it across the azimuth axis.
    z = np.repeat(radius * np.sin(elevation), resolution, axis=1)

    return x, y, z


def load_cell_types(file_path):
    """Read a header-less CSV file and return the values of its first column."""
    table = pd.read_csv(file_path, header=None)
    first_column = table.iloc[:, 0]
    return first_column.values


def _as_plotly_color(color):
    """Return ``color`` in a form Plotly accepts as a single marker colour.

    Strings are passed through unchanged. A length-3 or length-4 sequence of
    numbers is treated as RGB(A) with components in [0, 1] (the convention of
    matplotlib colormaps) and converted to an ``"rgba(r, g, b, a)"`` string.
    Anything else is returned unchanged.
    """
    if isinstance(color, str):
        return color
    try:
        channels = [float(channel) for channel in color]
    except (TypeError, ValueError):
        return color
    if len(channels) not in (3, 4):
        return color
    red, green, blue = (int(round(255 * min(max(channel, 0.0), 1.0))) for channel in channels[:3])
    alpha = channels[3] if len(channels) == 4 else 1.0
    return f"rgba({red}, {green}, {blue}, {alpha})"


def create_3d_visualization(embeddings, cell_types, color_map, save_path,
                            title="scPhere 3D Visualization", show_legend=False):
    """Write an interactive 3D scatter of ``embeddings`` around a sphere to HTML.

    Args:
        embeddings: Array indexed as ``embeddings[:, 0..2]`` for x, y and z.
        cell_types: One label per row of ``embeddings``; each must be a key of
            ``color_map``.
        color_map: Mapping from label to colour. Colours may be Plotly colour
            strings or RGB(A) sequences with components in [0, 1].
        save_path: Destination passed to ``fig.write_html``. When it is a
            filesystem path, its parent directory is created if missing.
        title: Figure title.
        show_legend: Whether to display the legend (one entry per label).

    Returns:
        None. The figure is only written to ``save_path``.
    """
    point_colors = [_as_plotly_color(color_map[cell_type]) for cell_type in cell_types]

    fig = go.Figure()

    fig.add_trace(go.Scatter3d(
        x=embeddings[:, 0],
        y=embeddings[:, 1],
        z=embeddings[:, 2],
        mode="markers",
        marker=dict(
            size=3,
            color=point_colors,
            opacity=0.6
        ),
        name="Cells"
    ))

    # Fully opaque reference sphere drawn with a single flat colour.
    x, y, z = create_sphere_surface()
    fig.add_trace(go.Surface(
        x=x, y=y, z=z,
        opacity=1.0,
        showscale=False,
        colorscale=[[0, "#dbd7d2"], [1, "#dbd7d2"]],
        name="Sphere"
    ))

    # Point-less traces whose only purpose is to provide one legend entry per label.
    for cell_type, color in color_map.items():
        fig.add_trace(go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode="markers",
            marker=dict(size=8, color=_as_plotly_color(color)),
            name=cell_type
        ))

    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False, showline=False)
    fig.update_layout(
        title=title,
        width=1000,
        height=1000,
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            aspectmode="data"
        ),
        showlegend=show_legend,
        legend=dict(
            yanchor="bottom",
            y=0.01,
            xanchor="center",
            x=0.5
        )
    )

    # Create the output directory first so a missing folder does not abort the
    # export; non-path destinations are handed to write_html untouched.
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    fig.write_html(save_path)

# NOTE(review): `s_latents` is not assigned in any cell visible here, so this line
# raises NameError on a fresh kernel unless it is defined elsewhere — confirm where
# it should come from (presumably the 3D latents of a spherical component).
embeddings = s_latents
save_path = "/home/romainlhardy/code/hyperbolic-cancer/animations/lung_mvae.html"
# `filter`, `cell_types` and `color_map` are the module-level variables created by
# the 2D plotting cell; `filter` is a boolean mask, not the Python built-in.
create_3d_visualization(embeddings[filter], cell_types[filter], color_map, save_path)