In [1]:
import os
import sys
sys.path.append("/home/romainlhardy/code/hyperbolic-cancer/PoincareMaps")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from matplotlib.colors import LinearSegmentedColormap, to_hex
from sklearn.decomposition import PCA
from PoincareMaps.data import prepare_data, compute_rfa
from PoincareMaps.model import PoincareEmbedding, PoincareDistance, poincare_root, poincare_translation
from PoincareMaps.rsgd import RiemannianSGD
from PoincareMaps.train import train
from PoincareMaps.visualize import plotPoincareDisc, plot_poincare_disc
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Identifier for this run; it is used as the stem of every output file written
# by the training and plotting cells.
# NOTE(review): the identifier says "msk_impact" while the directory below is a
# TCGA PanCancer Atlas stomach study -- confirm the naming is intentional.
dset = "msk_impact_stomach"

# Directory holding the tab-separated study files read by the loading cells
# (mutations, clinical sample/patient tables, mRNA expression z-scores).
data_dir = "/home/romainlhardy/data/hyperbolic-cancer/msk_impact/stad_tcga_pan_can_atlas_2018"

In [None]:
# Mutation data: one row per mutation call, tab-separated.
mutations_path = os.path.join(data_dir, "data_mutations.txt")
df_mutations = pd.read_csv(mutations_path, sep="\t")

# Number of mutation rows per gene symbol, and the distinct sample barcodes
# that appear in the mutation table.
gene_counts = df_mutations["Hugo_Symbol"].value_counts()
samples = df_mutations["Tumor_Sample_Barcode"].unique()

# Quick size summary of what was loaded.
print(f"Mutations: {df_mutations.shape[0]}")
print(f"Samples: {len(samples)}")
print(f"Genes: {len(gene_counts)}")

In [None]:
# Clinical data: both tables are tab-separated, and any line starting with "#"
# is skipped by the parser (comment="#").
clinical_read_kwargs = dict(sep="\t", comment="#")

df_clinical_samples = pd.read_csv(
    os.path.join(data_dir, "data_clinical_sample.txt"),
    **clinical_read_kwargs,
)
df_clinical_patients = pd.read_csv(
    os.path.join(data_dir, "data_clinical_patient.txt"),
    **clinical_read_kwargs,
)

print(f"Clinical samples: {df_clinical_samples.shape[0]}")
print(f"Clinical patients: {df_clinical_patients.shape[0]}")

In [None]:
# Gene expression data (per the file name: RSEM mRNA z-scores).
# Rows are genes; columns from index 2 onwards are treated as one column per sample.
df_expression = pd.read_csv(os.path.join(data_dir, "data_mrna_seq_v2_rsem_zscores_ref_all_samples.txt"), sep="\t")

# NOTE: this rebinds `samples` to the expression sample IDs; every later cell
# (labels, embedding, plots) is indexed by this list.
# Assumes the first two columns are gene identifier columns -- TODO confirm.
genes = df_expression["Entrez_Gene_Id"].tolist()
samples = df_expression.columns[2:].tolist()
print(len(samples))

# Feature matrix: transpose to (samples x genes), then drop every gene column
# that has a missing value for any sample.
# NOTE: `genes` above is NOT filtered by this dropna, so it can be longer than
# the number of columns in `expression_matrix`; do not index one with the other.
expression_matrix = df_expression.iloc[:, 2:].T.dropna(axis=1).values.astype(np.float32)
print(expression_matrix.shape)

# PCA down to 20 components.
# random_state is pinned because scikit-learn's default solver selection can
# pick the randomized SVD for a matrix of this shape, which would otherwise
# make `features` (and everything derived from it) differ between runs.
pca = PCA(n_components=20, random_state=0)
features = pca.fit_transform(expression_matrix)
features = torch.DoubleTensor(features)
print(features.shape)

In [None]:
# # Extract labels (cancer subtype)
# if "ONCOTREE_CODE" in df_clinical_samples.columns:
#     label_column = "ONCOTREE_CODE"
# else:
#     print("No suitable cancer type column found. Using sample IDs as placeholder labels.")
#     label_column = None

# labels = []
# for sample in samples:
#     if label_column:
#         sample_info = df_clinical_samples[df_clinical_samples["SAMPLE_ID"] == sample]
#         if len(sample_info) > 0:
#             label = sample_info[label_column].iloc[0]
#         else:
#             label = "Unknown"
#     else:
#         label = "Unknown"
#     labels.append(label)

# print(np.unique(labels))

# Extract labels (mutation counts, binned into quantile-based bins).
#
# Each sample gets an integer label: the index of the mutation-count bin it
# falls into. Bin edges are taken at evenly spaced positions of the sorted
# per-sample mutation counts.
mutation_counts = df_mutations["Tumor_Sample_Barcode"].value_counts().to_dict()

# Samples that have expression data but no row in the mutation table are given
# a count of 1, so they land in the lowest bin together with single-mutation
# samples. (Because of this fill, no sample ever has a count of 0, so there is
# no separate "zero mutations" label.)
for sample in samples:
    mutation_counts.setdefault(sample, 1)

sorted_counts = sorted(mutation_counts.values())
min_count = max(1, sorted_counts[0])
max_count = sorted_counts[-1]

# Bin edges: every `bin_size`-th value of the sorted counts. The index is
# clamped so the last edge never runs past the end of the list.
num_bins = 10
bin_size = len(sorted_counts) // num_bins
bins = [sorted_counts[min(i * bin_size, len(sorted_counts) - 1)] for i in range(num_bins + 1)]
# Widen the outer edges slightly so the minimum and maximum counts fall
# strictly inside the first and last bins rather than on an edge.
bins[0]  = min_count - 0.1
bins[-1] = max_count + 0.1

# Map every sample to its bin in a single vectorized call; the result is an
# integer array aligned with `samples`.
labels = np.digitize([mutation_counts[sample] for sample in samples], bins)
print(len(labels))

In [None]:
# Pairwise distances in the original data space, computed by PoincareMaps'
# `compute_rfa` from the PCA features. The option values are passed through
# unchanged; their exact semantics are defined in PoincareMaps.data.
rfa_options = dict(
    mode="features",
    k_neighbours=30,
    distlocal="minkowski",
    distfn="MFIsym",
    connected=True,
    sigma=1.0,
)
rfa = compute_rfa(features, **rfa_options)

In [None]:
# All tensors and the model are placed on this device.
device = "cuda"

# Move the target matrix to the device and build one index per row of it.
rfa = rfa.to(device)
indices = torch.arange(len(rfa)).to(device)

# Each dataset item is (row index, corresponding row of `rfa`).
dataset = TensorDataset(indices, rfa)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

print(f"Dataset size: {len(dataset)}")

In [9]:
# Embedding model with one entry per dataset item.
# The two positional arguments are presumably (number of points, embedding
# dimension) -- confirm against PoincareMaps.model.PoincareEmbedding.
embedding_options = dict(
    dist=PoincareDistance,
    max_norm=1,
    Qdist="laplace",
    lossfn="klSym",
    gamma=2.0,
    cuda=0,
)
predictor = PoincareEmbedding(len(dataset), 2, **embedding_options).to(device)

In [None]:
# Sanity check on a single batch before training.
batch = next(iter(dataloader))
inputs, targets = batch

# One output row per input index: [batch_size, len(dataset)]
outputs = predictor(inputs.to(device))

# Every output row should sum to 1.
expected_row_sums = torch.ones(len(batch[0])).to(device)
assert outputs.sum(dim=-1).allclose(expected_row_sums)

# Try to match the distance distributions in the data space and the embedding space
predictor.lossfn(outputs, targets)

In [11]:
# RiemannianSGD optimizer over the embedding model's parameters.
optimizer = RiemannianSGD(
    predictor.parameters(),
    lr=0.1,
)

In [None]:
class PoincareOptions:
    """Plain attribute container for the options object passed to `train`.

    Every constructor argument is stored unchanged as an attribute of the same
    name. The meaning of each option is defined by the code that reads it
    (PoincareMaps' training routine), not by this class.
    """

    def __init__(
        self,
        debugplot: bool = False,
        epochs: int = 500,
        batchsize: int = -1,
        lr: float = 0.1,
        burnin: int = 500,
        lrm: float = 1.0,
        earlystop: float = 0.0001,
        cuda: int = 0,
    ) -> None:
        # Plotting / schedule length / batching.
        self.debugplot = debugplot
        self.epochs = epochs
        self.batchsize = batchsize

        # Learning-rate related options.
        self.lr = lr
        self.burnin = burnin
        self.lrm = lrm

        # Stopping threshold and device selector.
        self.earlystop = earlystop
        self.cuda = cuda

# Training options: everything except epochs and batch size keeps its default.
opt = PoincareOptions(batchsize=16, epochs=10000)

# Output path stem handed to `train` as `fout`.
output_prefix = f"/home/romainlhardy/code/hyperbolic-cancer/data/outputs/{dset}"

# NOTE(review): `earlystop` is passed here as 1e-6 while `opt.earlystop` keeps
# its default of 0.0001 -- confirm which one `train` actually uses.
embeddings, loss, epoch = train(
    predictor,
    dataset,
    optimizer,
    opt,
    fout=output_prefix,
    labels=labels,
    earlystop=1e-6,
    color_dict=None,
)

In [13]:
# Re-import of the plotting helpers (they are also imported in the first cell).
from PoincareMaps.visualize import plotPoincareDisc, plot_poincare_disc

# Re-centre the embedding on a root point chosen from the samples labelled `root`.
# `poincare_root` picks the index and `poincare_translation` moves that point's
# embedding to the origin (exact selection rule is defined in PoincareMaps.model).
root = 1
root_hat = poincare_root(root, labels, features)
embeddings_rotated = poincare_translation(-embeddings[root_hat, :], embeddings)

# One colour per distinct label, spread evenly along a blue -> red gradient.
sorted_labels = sorted(np.unique(labels))
cmap = LinearSegmentedColormap.from_list("cool_to_hot", ["blue", "cyan", "green", "yellow", "orange", "red"])
# Guard the denominator: with a single distinct label `len(sorted_labels) - 1`
# is 0 and the division would raise ZeroDivisionError. With two or more labels
# the result is unchanged.
color_span = max(len(sorted_labels) - 1, 1)
colors = [cmap(i / color_span) for i in range(len(sorted_labels))]
color_dict = {label: to_hex(color) for label, color in zip(sorted_labels, colors)}

# Plot the raw (un-translated) embedding; the returned colour mapping is reused below.
color_dict = plotPoincareDisc(embeddings.T, labels, file_name=f"/home/romainlhardy/code/hyperbolic-cancer/data/outputs/{dset}_raw", color_dict=color_dict)

# Plot the re-centred embedding with the same colours.
plot_poincare_disc(
    embeddings_rotated,
    labels=labels,
    file_name=f"/home/romainlhardy/code/hyperbolic-cancer/data/outputs/{dset}_rot", 
    coldict=color_dict,
    d1=9.5, 
    d2=9.0
)