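# Comparison of a standard VAE and a geometric VAE

Both models are restored from checkpoints trained on the EB dataset (`eb_phate.h5ad`) and evaluated with the same `GeometricNBVAE` test loop, so their individual reconstruction, KL and geometry metrics (condition number, variance, magnification factor) can be compared.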

In [1]:
import sys

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns
import torch
from pytorch_lightning import Trainer

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.base.geodesic_ae import GeodesicAE
from scCFM.models.base.geometric_vae import GeometricNBVAE, GeometricNBAE
from scCFM.models.base.vae import VAE
from scCFM.models.manifold.geometry_metrics import compute_all_metrics

# Make the repository root importable for the notebook-local helpers below.
sys.path.insert(0, "../..")
from notebooks.utils import real_reconstructed_cells_adata
from paths import PROJECT_FOLDER

## Initialize the configurations

Datamodule hyperparameters (shared by all models)

In [2]:
# Datamodule settings shared by every model in this notebook.
# NOTE(review): `train_val_test_split=[1]` appears to place all cells in the
# training split; the evaluation cells use `train_dataloader()` accordingly — confirm.
datamodule_kwargs = {
    "path": PROJECT_FOLDER / "data/eb/processed/eb_phate.h5ad",
    "x_layer": "X_norm",
    "cond_keys": ["experimental_time", "leiden"],
    "use_pca": False,
    "n_dimensions": None,
    "train_val_test_split": [1],
    "batch_size": 1024,
    "num_workers": 2,
}

# Initialize datamodule. The config dict and the datamodule object keep
# separate names, so `datamodule` always refers to the scDataModule.
datamodule = scDataModule(**datamodule_kwargs)

### Model configurations (VAE, geometric wrapper, geodesic AE)

In [3]:
# Hyperparameters of the underlying NB-VAE (forwarded as `vae_kwargs`).
model_vae = dict(
    in_dim=datamodule.in_dim,
    n_epochs_anneal_kl=1000,
    kl_weight=None,
    likelihood="nb",
    dropout=False,
    learning_rate=0.001,
    dropout_p=False,
    model_library_size=True,
    batch_norm=True,
    kl_warmup_fraction=0.1,
    hidden_dims=[256, 10],
)

# Arguments of the GeometricNBVAE wrapper itself.
# NOTE(review): `fl_weight` presumably scales the `test/fl_loss` term reported at test time — confirm.
geometric = dict(
    compute_metrics_every=1,
    use_c=True,
    trainable_c=False,
    l2=True,
    eta_interp=0,
    interpolate_z=False,
    start_jac_after=0,
    fl_weight=0.1,
    detach_theta=True,
)

# Arguments of the GeodesicAE (same encoder width/depth as the VAEs).
geodesic = dict(
    in_dim=datamodule.in_dim,
    hidden_dims=[256, 10],
    batch_norm=True,
    dropout=False,
    dropout_p=False,
    likelihood="nb",
    learning_rate=0.001,
)

In [4]:
# The baseline and the geometric VAE share one architecture and config; only
# the checkpoints loaded into them differ.
# NOTE(review): the baseline is also wrapped in GeometricNBVAE (not the imported
# `VAE`), presumably so both report the same geometric test metrics — confirm.
vae, geometric_vae = (
    GeometricNBVAE(**geometric, vae_kwargs=model_vae) for _ in range(2)
)
geodesic_ae = GeodesicAE(**geodesic)

In [5]:
# Alternative: evaluate the GeometricNBAE variants (presumably the non-variational
# autoencoder counterparts) instead of the VAEs.
# NOTE(review): `model_ae` is not defined anywhere in this notebook, so these lines
# fail as-is if uncommented — define an AE config first or delete this cell.
# vae = GeometricNBAE(**geometric, vae_kwargs=model_ae)
# geometric_vae = GeometricNBAE(**geometric, vae_kwargs=model_ae)

In [6]:
CHECKPOINT_DIR = PROJECT_FOLDER / "checkpoints/ae/eb"


def load_weights(model, ckpt_name, ckpt_dir=CHECKPOINT_DIR):
    """Load the ``state_dict`` stored in a Lightning checkpoint into ``model``.

    Args:
        model: module whose parameters are overwritten in place.
        ckpt_name: file name of the checkpoint inside ``ckpt_dir``.
        ckpt_dir: directory holding the checkpoints.

    Returns:
        The value returned by ``model.load_state_dict`` (missing / unexpected
        keys). With the default ``strict=True`` a key mismatch raises instead.
    """
    # map_location="cpu" makes loading work even when the checkpoint holds CUDA
    # tensors and no GPU is available; the Trainer moves the model to the
    # accelerator at test time.
    checkpoint = torch.load(ckpt_dir / ckpt_name, map_location="cpu")
    return model.load_state_dict(checkpoint["state_dict"])


load_weights(vae, "best_model_vae_lib.ckpt")
load_weights(geometric_vae, "best_model_geometric_lib.ckpt")
load_weights(geodesic_ae, "best_model_geodesic_ae.ckpt")

<All keys matched successfully>

In [7]:
# Lightning's default `inference_mode=True` runs evaluation under
# `torch.inference_mode()`, which disables autograd. It is switched off here,
# presumably because the geometric test metrics (condition number,
# magnification factor) need gradients through the decoder — TODO confirm
# against the models' test_step.
trainer_vae = Trainer(inference_mode=False)
trainer_geometric = Trainer(inference_mode=False)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Evaluate the baseline VAE; metrics are logged under the "test/" prefix even
# though the data comes from `train_dataloader()` (with `train_val_test_split=[1]`
# this appears to be the whole dataset — confirm).
# NOTE: the baseline is evaluated with the same `fl_weight` as the geometric
# model, so its `test/loss` is dominated by `fl_weight * test/fl_loss` and is
# not comparable across models; compare the individual terms instead.
vae_test_metrics = trainer_vae.test(model=vae, dataloaders=datamodule.train_dataloader())
vae_test_metrics

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test/loss': 2281802844602368.0,
  'test/kl': 19.320337295532227,
  'test/lik': 490.4476013183594,
  'test/fl_loss': 2.281803166724915e+16,
  'test/norm': 118.23875427246094,
  'reg_weight': 0.0,
  'fl_weight': 0.10000001639127731,
  'condition_number': 330.0871887207031,
  'variance': 74.83828735351562,
  'magnification_factor': inf,
  'eu_kl_dist': 281.7516174316406}]

In [9]:
# Evaluate the geometric VAE on the same dataloader as the baseline, keeping the
# returned metrics so the two models can be compared programmatically.
# NOTE: `test/loss` includes `fl_weight * test/fl_loss`; compare individual
# terms (lik, kl, condition_number, magnification_factor, ...) across models.
geometric_vae_test_metrics = trainer_geometric.test(
    model=geometric_vae, dataloaders=datamodule.train_dataloader()
)
geometric_vae_test_metrics

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test/loss': 2607809280.0,
  'test/kl': 97.1597900390625,
  'test/lik': 525.8251953125,
  'test/fl_loss': 26078087168.0,
  'test/norm': 399.027099609375,
  'reg_weight': 0.0,
  'fl_weight': 0.10000001639127731,
  'condition_number': 311556.5625,
  'variance': 48.91563415527344,
  'magnification_factor': 317.0611877441406,
  'eu_kl_dist': 194.96453857421875}]