In [1]:
import cupy as cp
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import src.mvae.mt.mvae.utils as utils
import torch
import yaml

from functools import partial
from scipy import stats
from scipy.io import mmread
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from src.mvae.mt.mvae.components import *
from src.mvae.mt.mvae.distributions import *
from src.mvae.mt.mvae.models.gene_vae import GeneVAE
from src.mvae.mt.mvae.ops.hyperbolics import lorentz_to_poincare
from src.mvae.mt.mvae.ops.spherical import spherical_to_projected
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [None]:
config_path = "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_e2.yaml"
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

# Build the dataset from the "data.options" section of the YAML config and report its sizes.
dataset = GeneDataset(**config["data"]["options"])
for attribute in ("n_gene_r", "n_gene_p", "n_batch"):
    print(getattr(dataset, attribute))
print(len(dataset))

# Shuffled loader, used by the next cell to grab one random batch for inspection.
dataloader = DataLoader(dataset, batch_size=2048, num_workers=16, shuffle=True)

# Peek at a single randomly chosen item (unseeded, so the item differs between runs).
sample_idx = np.random.choice(len(dataset))
x_r, x_p, batch_idx = dataset[sample_idx]
print(x_r, x_p, batch_idx)
print(x_r.max())
print(x_r.max())

In [None]:
checkpoint_path = "/home/romainlhardy/code/hyperbolic-cancer/models/mvae/lung_mvae_e2.ckpt"
# Set checkpoint_path = None to inspect a freshly initialised (untrained) model instead.
# checkpoint_path = None

device = "cuda"
# These model options come from the loaded dataset rather than from the YAML config.
config["lightning"]["model"]["options"]["n_gene_r"] = dataset.n_gene_r
config["lightning"]["model"]["options"]["n_gene_p"] = dataset.n_gene_p
config["lightning"]["model"]["options"]["n_batch"] = dataset.n_batch
module = GeneModule(config).to(device)

if checkpoint_path is not None:
    # map_location loads tensors straight onto `device`, independent of the device the
    # checkpoint was saved from.
    module.load_state_dict(torch.load(checkpoint_path, map_location=device)["state_dict"])

model = module.model
model.eval()

# Inspection only: no_grad avoids building an autograd graph for the whole batch, which would
# otherwise stay alive through `outputs`, `r` and the samples kept in the notebook namespace.
x_r, x_p, batch_idx = next(iter(dataloader))
with torch.no_grad():
    outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))

    # First entry of the model's "reparametrized" output; exposes q_z and p_z distributions.
    r = outputs["reparametrized"][0]
    q_z = r.q_z
    p_z = r.p_z
    p_samples = p_z.rsample(torch.Size([1000]))
    q_samples = q_z.rsample(torch.Size([1000]))

print(p_samples.shape, q_samples.shape)
print(q_z.loc, q_z.scale)

In [7]:
def pairwise_distances_euclidean(src: cp.ndarray, dst: cp.ndarray) -> cp.ndarray:
    """Euclidean distance between every row of ``src`` and every row of ``dst``.

    Args:
        src: (n, d) array of query points.
        dst: (m, d) array of reference points.

    Returns:
        (n, m) array whose entry [a, b] is ||src[a] - dst[b]||_2.
    """
    assert src.ndim == dst.ndim == 2 and src.shape[1] == dst.shape[1]
    # Broadcast (n, 1, d) against (1, m, d) to materialise every pairwise difference vector.
    queries = src[:, cp.newaxis]
    references = dst[cp.newaxis]
    return cp.linalg.norm(queries - references, axis=-1)


def pairwise_distances_spherical(src: cp.ndarray, dst: cp.ndarray, radius: float = 1.0) -> cp.ndarray:
    """Great-circle distance between every row of ``src`` and every row of ``dst``.

    Uses d(x, y) = R * arccos(<x, y> / R^2), which assumes both inputs lie on the sphere
    ||x|| = R in ambient coordinates.

    Args:
        src: (n, d) array of query points.
        dst: (m, d) array of reference points.
        radius: sphere radius R. A Python float or any scalar accepted by ``float()``
            (e.g. a 0-d torch tensor).

    Returns:
        (n, m) float64 array of geodesic distances.
    """
    assert src.ndim == 2 and dst.ndim == 2 and src.shape[1] == dst.shape[1]
    radius = float(radius)
    # Compute in float64: arccos has infinite slope at 1, so float32 rounding of the cosine
    # (~6e-8) becomes angle errors of ~3e-4 and collapses close neighbours to distance 0.
    src64 = cp.asarray(src, dtype=cp.float64)
    dst64 = cp.asarray(dst, dtype=cp.float64)
    # Clip guards against |cos| slightly above 1 from rounding, where arccos is NaN.
    cos_theta = cp.clip((src64 @ dst64.T) / radius ** 2, -1.0, 1.0)
    return radius * cp.arccos(cos_theta)


def pairwise_distances_hyperboloid(src: cp.ndarray, dst: cp.ndarray, radius: float = 1.0) -> cp.ndarray:
    """Geodesic distance on the hyperboloid between every row of ``src`` and ``dst``.

    Column 0 is the time-like coordinate. With the Lorentzian product
    <x, y>_L = x_0 y_0 - sum_{i>=1} x_i y_i, the distance is R * arccosh(<x, y>_L / R^2).
    This assumes points lie on the sheet <x, x>_L = R^2.

    Args:
        src: (n, d + 1) array of query points.
        dst: (m, d + 1) array of reference points.
        radius: hyperboloid radius R. A Python float or any scalar accepted by ``float()``
            (e.g. a 0-d torch tensor).

    Returns:
        (n, m) float64 array of geodesic distances.
    """
    assert src.ndim == 2 and dst.ndim == 2 and src.shape[1] == dst.shape[1]
    radius = float(radius)
    # Compute in float64. x_0 y_0 and the spatial dot product both grow exponentially with
    # distance from the origin and nearly cancel for nearby points; arccosh's infinite slope
    # at 1 then amplifies the float32 rounding left in the difference.
    src64 = cp.asarray(src, dtype=cp.float64)
    dst64 = cp.asarray(dst, dtype=cp.float64)
    lorentz = cp.outer(src64[:, 0], dst64[:, 0]) - src64[:, 1:] @ dst64[:, 1:].T
    # Clip guards against values slightly below 1 from rounding, where arccosh is NaN.
    cosh_arg = cp.clip(lorentz / radius ** 2, 1.0, None)
    return radius * cp.arccosh(cosh_arg)

In [None]:
def get_latents(reparametrized, num_components=1):
    """Collect ``q_z.loc`` of every latent component across all batches.

    Args:
        reparametrized: non-empty list with one entry per batch. Each entry is a sequence
            with one object per latent component exposing ``.q_z.loc`` (a torch tensor).
        num_components: number of latent components.

    Returns:
        List of ``num_components`` numpy arrays, one per component, formed by concatenating
        the per-batch ``loc`` tensors along axis 0. This presumes each ``loc`` is
        (batch_size, dim); TODO confirm.
    """
    assert len(reparametrized) > 0

    per_component = [[] for _ in range(num_components)]
    for batch_components in reparametrized:
        for component_idx, component_out in enumerate(batch_components):
            loc = component_out.q_z.loc.detach().cpu().numpy()
            per_component[component_idx].append(loc)

    return [np.concatenate(chunks, axis=0) for chunks in per_component]

# Unshuffled pass over the whole dataset, so latent rows follow dataset index order.
# The label cell presumably relies on this matching the metadata row order; TODO confirm.
dataloader = DataLoader(dataset, batch_size=2048, num_workers=16, shuffle=False)

reparametrized = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        x_r, x_p, batch_idx = batch
        outputs = model(x_r.to(device), x_p.to(device), batch_idx.to(device))
        reparametrized.append(outputs["reparametrized"])

num_components = len(model.components)
latents = get_latents(reparametrized, num_components=num_components)

In [None]:
def get_distance_fns(model):
    """Build one pairwise distance function per latent component of ``model``.

    Euclidean components map to ``pairwise_distances_euclidean``. Spherical and hyperbolic
    components map to ``pairwise_distances_spherical`` / ``pairwise_distances_hyperboloid``
    with the component manifold's radius bound as a Python float.

    Args:
        model: object exposing an iterable ``components`` attribute.

    Returns:
        List of callables ``fn(src, dst) -> (n, m) distances``, in component order.

    Raises:
        ValueError: if a component is not Euclidean, spherical or hyperbolic.
    """
    distance_fns = []
    for component in model.components:
        if isinstance(component, EuclideanComponent):
            distance_fns.append(pairwise_distances_euclidean)
            continue

        if isinstance(component, SphericalComponent):
            base_fn = pairwise_distances_spherical
        elif isinstance(component, HyperbolicComponent):
            base_fn = pairwise_distances_hyperboloid
        else:
            raise ValueError(f"Unsupported latent component type: {type(component).__name__}")

        # The manifold radius may be a (possibly CUDA, grad-tracking) torch scalar rather than
        # a float (TODO confirm). float() handles both and keeps torch objects out of CuPy math.
        distance_fns.append(partial(base_fn, radius=float(component.manifold.radius)))
    return distance_fns


def knn_accuracy(latents, labels, distance_fns, neighbors=15, batch_size=16):
    """Leave-one-out k-nearest-neighbour classification accuracy in the product latent space.

    The distance between two samples combines the per-component distances as
    sqrt(sum_k d_k(x_k, y_k)^2). Each sample is predicted as the most common label among its
    ``neighbors`` nearest *other* samples; the sample itself is excluded.

    Args:
        latents: list with one 2-D array per latent component. Rows are samples and share
            the same order across components.
        labels: 1-D array-like of integer labels, one per sample.
        distance_fns: one pairwise distance callable per component, ``fn(src, dst)``
            returning an (len(src), len(dst)) CuPy array.
        neighbors: number of nearest neighbours voting for each sample.
        batch_size: number of query rows per step. Bounds the (batch_size, num_samples)
            distance matrix held on the GPU.

    Returns:
        Fraction of samples whose majority-vote neighbour label equals their own label.
    """
    num_samples = len(latents[0])
    assert num_samples == len(labels)
    assert len(latents) == len(distance_fns)

    labels = np.asarray(labels)
    # Upload every component to the GPU once instead of re-uploading it for each batch.
    latents_device = [cp.asarray(component) for component in latents]

    correct = 0
    total = 0

    num_batches = (num_samples + batch_size - 1) // batch_size
    for i in tqdm(range(0, num_samples, batch_size), total=num_batches):
        j = min(i + batch_size, num_samples)
        batch_labels = labels[i:j]

        squared_distances = [
            distance_fn(component[i:j], component) ** 2
            for component, distance_fn in zip(latents_device, distance_fns)
        ]
        pairwise_distances = cp.sqrt(sum(squared_distances))

        # Exclude each query from its own neighbourhood. Query row r is global sample i + r,
        # so its self-distance sits in column i + r, not column r.
        rows = cp.arange(j - i)
        pairwise_distances[rows, rows + i] = float("inf")

        neighbor_indices = cp.asnumpy(cp.argsort(pairwise_distances, axis=-1)[:, :neighbors])
        neighbor_labels = labels[neighbor_indices]
        # ravel: SciPy < 1.11 keeps the reduced axis (shape (b, 1)), which would broadcast
        # against batch_labels into a (b, b) comparison and overcount matches.
        predicted_labels = np.ravel(stats.mode(neighbor_labels, axis=-1).mode)

        correct += int((predicted_labels == batch_labels).sum())
        total += len(batch_labels)

    return correct / total


metadata_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
column_name = "mp_assignment"
if metadata_path is not None:
    # Only the label column is needed. Missing labels get their own "Unknown" class.
    # Everything is cast to str so np.unique sees one sortable dtype; otherwise a numeric
    # column with NaNs would become an unsortable float/str object mix.
    labels = (
        pd.read_csv(metadata_path, sep="\t", usecols=[column_name])[column_name]
        .fillna("Unknown")
        .astype(str)
        .to_numpy()
    )
else:
    labels = np.ones((len(dataset),))  # Dummy labels

# Encode labels as integer class ids: positions in the sorted array of unique labels.
unique_labels, labels = np.unique(labels, return_inverse=True)
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}

knn_accuracy(latents, labels, get_distance_fns(model), neighbors=15)