In [1]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv
from pathlib import Path
import seaborn as sns
import pertpy as pt

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

from tqdm import tqdm
import time
import pandas as pd
import warnings

Initialize datamodule

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

datamodule_config={'path': PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad', 
                    'x_layer': 'X_norm', 
                    'cond_keys': ['experimental_time', 'leiden'],
                    'use_pca': False, 
                    'n_dimensions': None, 
                    'train_val_test_split': [1], 
                    'batch_size': 256, 
                    'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule_config)

Initialize configurations

In [3]:
vae_kwargs={'in_dim': datamodule.in_dim,
               'n_epochs_anneal_kl': 1000, 
               'kl_weight': None, 
               'likelihood': 'nb', 
               'dropout': False, 
               'learning_rate': 0.001, 
               'dropout_p': False, 
               'model_library_size': True, 
               'batch_norm': True, 
               'kl_warmup_fraction': 0.1, 
               'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
                   'use_c': False, 
                   'trainable_c': False,
                   'l2': True, 
                   'eta_interp': 0, 
                   'interpolate_z': False, 
                   'start_jac_after': 0, 
                   'fl_weight': 0.1,
                   'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
                  "hidden_dims": [256, 10],
                  "batch_norm": True,
                  "dropout": False, 
                  "dropout_p": False,
                  "likelihood": "nb",
                  "learning_rate": 0.001}

In [4]:
# Suppress warnings
warnings.filterwarnings("ignore")

# Define batch sizes to test
batch_sizes = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
num_examples = 10  # Number of examples to run for each batch size

## Without backpropagation

In [5]:
# Initialize empty lists to store results
results = {"model": [],
           "batch_size": [], 
           "runtime": []
          }

# Iterate over batch sizes
for batch_size in batch_sizes:
    # Repeat for each batch size
    for _ in range(num_examples):
        # Set batch size in datamodule
        datamodule_config["batch_size"] = batch_size
        datamodule = scDataModule(**datamodule_config)
        batch = next(iter(datamodule.train_dataloader()))
        batch["X"] = batch["X"].cuda()
        
        # Initialize vae and geometric vae
        vae = VAE(**vae_kwargs).to(device)
        geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
        geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

        # Time forward step for geodesic_ae
        start_time = time.time()
        geodesic_ae.step(batch, "train")
        end_time = time.time()
        geodesic_ae_runtime = end_time - start_time
        results["model"].append("geodesic")
        results["batch_size"].append(batch_size)
        results["runtime"].append(geodesic_ae_runtime)

        # Time forward step for geometric_vae
        start_time = time.time()
        geometric_vae.step(batch, "train")
        end_time = time.time()
        geometric_ae_runtime = end_time - start_time
        results["model"].append("geometric")
        results["batch_size"].append(batch_size)
        results["runtime"].append(geometric_ae_runtime)

        # Time forward step for vae
        start_time = time.time()
        vae.step(batch, "train")
        end_time = time.time()
        vae_runtime = end_time - start_time
        results["model"].append("vae")
        results["batch_size"].append(batch_size)
        results["runtime"].append(vae_runtime)

        del batch 
        del vae
        del geometric_vae
        del geodesic_ae
        
# Create dataframe
df = pd.DataFrame(results)

In [6]:
df.groupby(["model","batch_size"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,runtime
model,batch_size,Unnamed: 2_level_1
geodesic,4,1.831309
geodesic,8,0.11864
geodesic,16,0.094533
geodesic,32,0.114758
geodesic,64,0.098721
geodesic,128,0.116311
geodesic,256,0.217593
geodesic,512,0.512575
geodesic,1024,1.521877
geodesic,2048,7.01141


In [7]:
df.groupby(["model","batch_size"]).std() / np.sqrt(10)

Unnamed: 0_level_0,Unnamed: 1_level_0,runtime
model,batch_size,Unnamed: 2_level_1
geodesic,4,1.697728
geodesic,8,0.011796
geodesic,16,0.000549
geodesic,32,0.01942
geodesic,64,0.002376
geodesic,128,0.006646
geodesic,256,0.017366
geodesic,512,0.020542
geodesic,1024,0.013141
geodesic,2048,0.019636


## With backpropagation

In [8]:
import warnings
import pandas as pd
import time

# Suppress warnings
warnings.filterwarnings("ignore")

# Initialize empty lists to store results
results_backprop = {"model": [],
                    "batch_size": [],
                    "runtime": []}


# Iterate over batch sizes
for batch_size in batch_sizes:
    # Repeat for each batch size
    for _ in range(num_examples):
        # Set batch size in datamodule
        datamodule_config["batch_size"] = batch_size
        datamodule = scDataModule(**datamodule_config)
        batch = next(iter(datamodule.train_dataloader()))
        batch["X"] = batch["X"].cuda()
        
        # Initialize vae and geometric vae
        vae = VAE(**vae_kwargs).to(device)
        geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
        geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

        # Initialize optimizers for each model
        vae_optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
        geometric_vae_optimizer = torch.optim.Adam(geometric_vae.parameters(), lr=0.001)
        geodesic_ae_optimizer = torch.optim.Adam(geodesic_ae.parameters(), lr=0.001)

        # Time forward step and backward step for geodesic_ae
        start_time = time.time()
        loss = geodesic_ae.step(batch, "train")
        geodesic_ae_optimizer.zero_grad()  # Clear gradients before backward pass
        loss.backward()
        geodesic_ae_optimizer.step()  # Update parameters based on gradients
        end_time = time.time()
        geodesic_ae_runtime = end_time - start_time
        results_backprop["model"].append("geodesic")
        results_backprop["batch_size"].append(batch_size)
        results_backprop["runtime"].append(geodesic_ae_runtime)

        # Time forward step and backward step for geometric_vae
        start_time = time.time()
        loss = geometric_vae.step(batch, "train")
        geometric_vae_optimizer.zero_grad()  # Clear gradients before backward pass
        loss.backward()
        geometric_vae_optimizer.step()  # Update parameters based on gradients
        end_time = time.time()
        geometric_ae_runtime = end_time - start_time
        results_backprop["model"].append("geometric")
        results_backprop["batch_size"].append(batch_size)
        results_backprop["runtime"].append(geometric_ae_runtime)

        # Time forward step and backward step for vae
        start_time = time.time()
        loss = vae.step(batch, "train")
        vae_optimizer.zero_grad()  # Clear gradients before backward pass
        loss.backward()
        vae_optimizer.step()  # Update parameters based on gradients
        end_time = time.time()
        vae_runtime = end_time - start_time
        results_backprop["model"].append("vae")
        results_backprop["batch_size"].append(batch_size)
        results_backprop["runtime"].append(vae_runtime)

        del batch 
        del vae
        del geometric_vae
        del geodesic_ae
        
# Create dataframe
df_backprop = pd.DataFrame(results_backprop)


In [9]:
df_backprop.groupby(["model","batch_size"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,runtime
model,batch_size,Unnamed: 2_level_1
geodesic,4,0.122397
geodesic,8,0.10628
geodesic,16,0.09779
geodesic,32,0.096211
geodesic,64,0.1176
geodesic,128,0.220128
geodesic,256,0.272736
geodesic,512,0.479076
geodesic,1024,1.608938
geodesic,2048,8.1816


In [10]:
df_backprop.groupby(["model","batch_size"]).std() / np.sqrt(10)

Unnamed: 0_level_0,Unnamed: 1_level_0,runtime
model,batch_size,Unnamed: 2_level_1
geodesic,4,0.017907
geodesic,8,0.007079
geodesic,16,0.003151
geodesic,32,0.001001
geodesic,64,0.008683
geodesic,128,0.020137
geodesic,256,0.022164
geodesic,512,0.011237
geodesic,1024,0.08145
geodesic,2048,0.310014


In [11]:
batch

NameError: name 'batch' is not defined