In [None]:
import os
import sys
sys.path.append("/home/roh3635/hyperbolic-cancer/PoincareMaps")

import torch

from PoincareMaps.data import prepare_data, compute_rfa
from PoincareMaps.model import PoincareEmbedding, PoincareDistance
from PoincareMaps.rsgd import RiemannianSGD
from PoincareMaps.train import train
from torch.utils.data import TensorDataset, DataLoader

In [18]:
# Which bundled PoincareMaps dataset to load; the name is interpolated into the path below.
dset = "ToggleSwitch"
root = "root"  # NOTE(review): not used in the visible cells — kept in case later cells rely on it

dataset_path = f"/home/roh3635/hyperbolic-cancer/PoincareMaps/datasets/{dset}"
# Load raw features and labels without normalisation or PCA (n_pca=0).
features, labels = prepare_data(
    dataset_path,
    n_pca=0,
    normalize=False,
    with_labels=True,
)

In [None]:
# Build the target matrix from the raw features with PoincareMaps' compute_rfa.
# NOTE(review): earlier notes describe this as "pairwise distances in the original data space";
# RFA is presumably an affinity/similarity matrix (one row per sample) — confirm in compute_rfa.
rfa = compute_rfa(
    features,
    mode="features",
    distfn="MFIsym",
    distlocal="minkowski",
    k_neighbours=15,  # presumably the neighbourhood size of the kNN graph
    sigma=1.0,
    connected=True,
)

In [20]:
# Each training example pairs a row index with the corresponding row of `rfa`,
# so the model is asked to reproduce row i of `rfa` for point i.
device = "cpu"

indices = torch.arange(len(rfa)).to(device)
rfa = rfa.to(device)

dataset = TensorDataset(indices, rfa)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=64,
)

In [21]:
# One learnable point per sample; the predictor's output for a batch of indices is
# compared against the matching `rfa` rows through `predictor.lossfn`.
predictor = PoincareEmbedding(
    len(dataset),  # number of embedded points — one per row of `rfa`
    2,             # presumably the embedding dimension (2-D Poincaré disc) — confirm in model.py
    dist=PoincareDistance,
    max_norm=1,
    Qdist="laplace",
    lossfn="klSym",
    gamma=1.0,
    cuda=0,        # NOTE(review): presumably 0 = no CUDA, consistent with device="cpu" above
)
# Attributes as noted by the original author (verify against PoincareMaps/model.py):
#   predictor.size   -> number of points (len(dataset))
#   predictor.lt     -> embedding table, indexed by the integer inputs
#   predictor.dist   -> distance function (PoincareDistance)
#   predictor.lossfn -> loss comparing outputs with targets ("klSym" here)
#   predictor.Qdist  -> kernel/distribution family applied to embedding distances ("laplace")
#   predictor.gamma  -> temperature-like scale (1.0)

In [None]:
# Sanity-check a single mini-batch before training:
#   1. each output row must sum to 1 (a normalised distribution over all points), and
#   2. the loss against the matching `rfa` rows must be computable.
batch = next(iter(dataloader))
inputs, targets = batch

# Inspection only: no backward pass follows, so skip building an autograd graph.
with torch.no_grad():
    outputs = predictor(inputs)  # [batch_size, len(dataset)]
    row_sums = outputs.sum(dim=-1)
    # ones_like matches the dtype/device of the outputs; a bare torch.ones(...) is always
    # CPU float32 and makes allclose fail for GPU or float64 outputs.
    assert torch.allclose(row_sums, torch.ones_like(row_sums)), "predictor output rows do not sum to 1"
    # Try to match the distributions in the data space (targets) and the embedding space (outputs).
    batch_loss = predictor.lossfn(outputs, targets)

batch_loss

In [23]:
# Riemannian SGD over the embedding parameters.
# NOTE(review): train() below also receives opt.lr (default 0.1) — confirm which value takes effect.
optimizer = RiemannianSGD(
    predictor.parameters(),
    lr=0.1,
)

In [None]:
class PoincareOptions:
    """Plain options object passed as the `args` argument of PoincareMaps' `train`.

    Every constructor argument is stored as an attribute of the same name.

    Args:
        debugplot: debug-plot setting (False disables it, presumably — confirm in train).
        epochs: number of training epochs.
        batchsize: mini-batch size; the default of -1 is meant to be overridden before training.
        lr: base learning rate.
        burnin: number of burn-in epochs.
        lrm: learning-rate multiplier (presumably applied during burn-in — confirm in train).
        earlystop: early-stopping tolerance.
        cuda: CUDA flag (0 in this notebook).
    """

    def __init__(self, debugplot=False, epochs=500, batchsize=-1, lr=0.1, burnin=500, lrm=1.0, earlystop=0.0001, cuda=0):
        self.debugplot = debugplot
        self.epochs = epochs
        self.batchsize = batchsize
        self.lr = lr
        self.burnin = burnin
        self.lrm = lrm
        # These two were previously accepted but silently discarded; keep them so the
        # options object actually reflects what it was constructed with.
        self.earlystop = earlystop
        self.cuda = cuda

# Train with batch size 16 (overriding the class default of -1).
# NOTE(review): the DataLoader built earlier uses batch_size=64, but train() is handed the raw
# `dataset` plus `opt`, so opt.batchsize is presumably what it uses — confirm in train.py.
opt = PoincareOptions(batchsize=16)

output_csv = "/home/roh3635/hyperbolic-cancer/data/outputs/test.csv"
embeddings, loss, epoch = train(
    predictor,
    dataset,
    optimizer,
    opt,
    fout=output_csv,
    labels=labels,
    earlystop=0.0001,
    color_dict=None,
)