In [None]:
import anndata as ad
import cupy as cp
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
import src.mvae.mt.mvae.utils as utils
import torch
import umap
import umap.umap_ as umap_
import yaml

from functools import partial
from scipy import stats
from scipy.io import mmread
from sklearn.decomposition import PCA
from src.lightning.gene import GeneModule
from src.mvae.mt.data import GeneDataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [2]:
# AnnData (.h5ad) file whose `.obsm` mapping holds the embedding to evaluate.
save_path = "/home/romainlhardy/code/hyperbolic-cancer/models/umap/lung_umap10.h5ad"
# Name of the `.obsm` entry to read (presumably a Harmony-corrected PCA embedding — confirm).
key = "X_pca_harmony"
# `latents` is a list with one array per latent component; a single component is loaded here.
latents = []
latents.append(sc.read_h5ad(save_path).obsm[key])

In [3]:
def pairwise_distances_euclidean(src: cp.ndarray, dst: cp.ndarray) -> cp.ndarray:
    """Return the matrix of Euclidean distances between rows of `src` and rows of `dst`.

    Args:
        src: 2-D array; each row is one query point.
        dst: 2-D array with the same number of columns as `src`.

    Returns:
        2-D array whose entry ``[a, b]`` is the L2 norm of ``src[a] - dst[b]``.

    Raises:
        AssertionError: if either input is not 2-D or the column counts differ.
    """
    assert src.ndim == 2
    assert dst.ndim == 2
    assert src.shape[1] == dst.shape[1]
    # Broadcast to a (len(src), len(dst), dim) array of row differences, then
    # collapse the trailing feature axis. Note that this materialises the full
    # 3-D intermediate, so memory grows with len(src) * len(dst) * dim.
    deltas = src[:, None] - dst[None]
    return cp.linalg.norm(deltas, axis=2)


def get_distance_fns():
    """Return the list of pairwise-distance callables used for evaluation.

    The list currently holds a single entry, the Euclidean distance function.
    """
    distance_fns = [pairwise_distances_euclidean]
    return distance_fns

In [None]:
def knn_accuracy(latents, labels, distance_fns, neighbors=15, batch_size=16):
    """Leave-one-out k-nearest-neighbour label accuracy over one or more latent components.

    For every sample, the `neighbors` closest *other* samples are found and the
    most frequent label among them is taken as the prediction. Distances are
    combined across components as ``sqrt(sum_k d_k ** 2)``, where ``d_k`` is the
    distance returned by ``distance_fns[k]`` on ``latents[k]``.

    Args:
        latents: Sequence of array-likes, one per component. Each must be
            convertible with ``cp.asarray`` and have one row per sample.
        labels: Array-like of per-sample labels, same length as each latent.
        distance_fns: One callable per entry of `latents`, called as
            ``fn(batch, all_points)`` and expected to return a
            ``(len(batch), num_samples)`` distance matrix.
        neighbors: Number of nearest neighbours that vote on the prediction.
        batch_size: Number of query samples processed per iteration; bounds the
            size of the per-batch distance matrix.

    Returns:
        Fraction of samples whose majority neighbour label equals their own label.

    Raises:
        AssertionError: if `labels` and `latents[0]` differ in length, or if
            `latents` and `distance_fns` differ in length.
    """
    num_samples = len(latents[0])
    assert num_samples == len(labels)
    assert len(latents) == len(distance_fns)

    # Fancy indexing with a 2-D index array (below) requires an ndarray.
    labels = np.asarray(labels)
    # Transfer each full component to the device once, not once per batch.
    device_latents = [cp.asarray(component) for component in latents]

    correct = 0
    total = 0

    num_batches = (num_samples + batch_size - 1) // batch_size
    for i in tqdm(range(0, num_samples, batch_size), total=num_batches):
        j = min(i + batch_size, num_samples)
        batch_labels = labels[i : j]

        squared_distances = [
            distance_fn(component[i : j], component) ** 2
            for distance_fn, component in zip(distance_fns, device_latents)
        ]
        pairwise_distances = cp.sqrt(sum(squared_distances))

        # Exclude each query point from its own neighbour list. Batch row r is
        # global sample i + r, so its self-distance lives in column i + r (not
        # column r, which is only correct for the very first batch).
        pairwise_distances[cp.arange(j - i), cp.arange(i, j)] = float("inf")
        neighbor_indices = cp.argsort(pairwise_distances, axis=-1)[:, :neighbors].get()

        neighbor_labels = labels[neighbor_indices]
        # Older SciPy keeps the reduced axis (shape (B, 1)); flatten so the
        # comparison with batch_labels is element-wise rather than broadcast.
        predicted_labels = np.ravel(stats.mode(neighbor_labels, axis=-1).mode)

        correct += (predicted_labels == batch_labels).sum()
        total += len(batch_labels)

    return correct / total


# Tab-separated metadata table; `column_name` selects the column used as labels.
# Set `metadata_path` to None to fall back to a single dummy label for every sample.
metadata_path = "/home/romainlhardy/code/hyperbolic-cancer/data/lung/metadata.tsv"
column_name = "mp_assignment"
if metadata_path is None:
    # No annotations: every sample gets the same placeholder label.
    labels = np.ones((len(latents[0]),))
else:
    # Missing annotations are grouped under an explicit "Unknown" class.
    # NOTE(review): this assumes metadata rows are in the same order as the rows
    # of the latent arrays — only the lengths are checked (inside knn_accuracy). Confirm.
    labels = pd.read_csv(metadata_path, sep="\t")[column_name].replace(np.nan, "Unknown").values

# Encode labels as integer codes, numbered in the sorted order np.unique returns.
unique_labels = np.unique(labels)
label_to_idx = dict(zip(unique_labels, range(len(unique_labels))))
labels = np.array([label_to_idx[name] for name in labels])

# Last expression of the cell: the kNN label accuracy is displayed as the cell output.
knn_accuracy(latents, labels, get_distance_fns(), neighbors=15)