# Comparison of the VAE and the geometric VAE

In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE
from scCFM.datamodules.sc_datamodule import scDataModule

from scCFM.models.utils import jacobian_decoder_jvp_parallel
from scCFM.models.manifold.geometry_metrics import compute_all_metrics

import sys
sys.path.insert(0, "../..")
from notebooks.utils import real_reconstructed_cells_adata

from paths import PROJECT_FOLDER
import anndata
import torch

  self.seed = seed
  self.dl_pin_memory_gpu_training = (
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Initialize the two different configurations 

General data module hyperparameters

In [2]:
# Keyword arguments for the data module. They live under their own name so the
# settings dict is not shadowed by the scDataModule instance built from it.
datamodule_kwargs = {
    'path': PROJECT_FOLDER / 'data' / 'schiebinger_et_al' / 'processed' / 'schiebinger_et_al.h5ad',
    'x_layer': 'X_norm',
    'cond_keys': ['experimental_time', 'cell_sets'],
    'use_pca': False,
    'n_dimensions': None,
    # NOTE(review): a single-entry split presumably puts every cell in the
    # training split; the evaluation cells below read train_dataloader().
    'train_val_test_split': [1],
    'batch_size': 32,
    'num_workers': 2,
}

# Initialize datamodule
datamodule = scDataModule(**datamodule_kwargs)

### Model configs: VAE, geometric regularization, and geodesic AE

In [3]:
# Base VAE settings; handed to GeometricNBVAE through its `vae_kwargs` argument.
model_vae = dict(
    in_dim=datamodule.in_dim,
    n_epochs_anneal_kl=1000,
    kl_weight=None,
    likelihood='nb',
    dropout=False,
    learning_rate=0.001,
    dropout_p=False,
    model_library_size=True,
    batch_norm=True,
    kl_warmup_fraction=0.1,
    hidden_dims=[256, 10],
)

# Additional settings unpacked directly into GeometricNBVAE.
geometric = dict(
    compute_metrics_every=1,
    use_c=True,
    l2=True,
    eta_interp=0,
    interpolate_z=False,
    start_jac_after=0,
    fl_weight=0.1,
    detach_theta=True,
)

# Settings unpacked into GeodesicAE; same input size and hidden layout as the VAE.
geodesic = dict(
    in_dim=datamodule.in_dim,
    hidden_dims=[256, 10],
    batch_norm=True,
    dropout=False,
    dropout_p=False,
    likelihood="nb",
    learning_rate=0.001,
)

In [4]:
# `vae` and `geometric_vae` are built from identical arguments; they differ
# only in the checkpoint that is loaded into each of them afterwards.
geometric_vae_kwargs = {**geometric, "vae_kwargs": model_vae}

vae = GeometricNBVAE(**geometric_vae_kwargs)
geometric_vae = GeometricNBVAE(**geometric_vae_kwargs)
geodesic_ae = GeodesicAE(**geodesic)

In [5]:
# Folder holding the best checkpoint of each model for this dataset.
CHECKPOINT_DIR = PROJECT_FOLDER / "checkpoints/ae/schiebinger_et_al"


def load_checkpoint_weights(model, checkpoint_name):
    """Load the ``state_dict`` stored in a checkpoint file into ``model``.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_name: File name of the checkpoint inside ``CHECKPOINT_DIR``.

    Returns:
        The result of ``model.load_state_dict`` (reports missing/unexpected keys).
    """
    # map_location="cpu" lets checkpoints saved from CUDA tensors be read on a
    # machine without a GPU; load_state_dict then copies the values onto
    # whatever device the model's parameters already live on.
    checkpoint = torch.load(CHECKPOINT_DIR / checkpoint_name, map_location="cpu")
    return model.load_state_dict(checkpoint["state_dict"])


load_checkpoint_weights(vae, "best_model_vae.ckpt")
load_checkpoint_weights(geometric_vae, "best_model_geometric.ckpt")
# Last expression of the cell, so its key-matching report is displayed.
load_checkpoint_weights(geodesic_ae, "best_model_geodesic.ckpt")

<All keys matched successfully>

In [6]:
# Put all three models in evaluation mode and rebind each name to the object
# returned by its own .eval() call.
vae, geometric_vae, geodesic_ae = (
    model.eval() for model in (vae, geometric_vae, geodesic_ae)
)

In [7]:
from pytorch_lightning import Trainer

# Both trainers share the same settings. With inference_mode=False Lightning
# evaluates under torch.no_grad() instead of torch.inference_mode().
# NOTE(review): presumably needed so the geometry metrics computed during
# validation can still use autograd - confirm.
trainer_kwargs = {"inference_mode": False}

trainer_vae = Trainer(**trainer_kwargs)
trainer_geometric = Trainer(**trainer_kwargs)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# Score the baseline VAE on the training dataloader; the returned list of
# metric dicts is the cell's displayed output.
vae_eval_loader = datamodule.train_dataloader()
trainer_vae.validate(model=vae, dataloaders=vae_eval_loader)

  rank_zero_warn(
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]



[{'val/loss': 5.20633295454448e+24,
  'val/kl': 21.81983757019043,
  'val/lik': 508.5929870605469,
  'val/fl_loss': 5.2063251146782485e+25,
  'val/norm': 22.281217575073242,
  'reg_weight': 0.0,
  'fl_weight': 0.10000395029783249,
  'condition_number': 542.662353515625,
  'variance': 120.16801452636719,
  'magnification_factor': inf,
  'eu_kl_dist': 754.1693725585938,
  'eu_kl_corr': 0.5510973930358887}]

In [9]:
# Score the geometric VAE on the training dataloader with its own trainer
# (`trainer_geometric` was created for this model but `trainer_vae` was being
# reused here). The returned list of metric dicts is the cell's output.
trainer_geometric.validate(model=geometric_vae, dataloaders=datamodule.train_dataloader())

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

[{'val/loss': 22499238.0,
  'val/kl': 132.962890625,
  'val/lik': 554.2285766601562,
  'val/fl_loss': 224986656.0,
  'val/norm': 87.03494262695312,
  'reg_weight': 0.0,
  'fl_weight': 0.10000395029783249,
  'condition_number': 13211.5400390625,
  'variance': 21.773656845092773,
  'magnification_factor': 23793969152.0,
  'eu_kl_dist': 356.65728759765625,
  'eu_kl_corr': 0.34404534101486206}]