In [2]:
%load_ext autoreload
%autoreload 2

import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from torchcfm.conditional_flow_matching import *
import scanpy as sc
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# load some data
adata = sc.read_h5ad('/orcd/archive/abugoot/001/Projects/dlesman/datasets/kaggle_HVG.h5ad')

# here we set up the train/eval and control/pert sets
# set the idx of the controls
control_idx = adata.obs['sm_name'] == 'Dimethyl Sulfoxide'
# set the idx of the perts (currently just "all not control")
pert_idx = adata.obs['sm_name'] != 'Dimethyl Sulfoxide'
# set the hold out cell-type/pert
eval_cell_idx = adata.obs.cell_type == 'B cells'
eval_pert_idx = adata.obs['sm_name'] == 'Belinostat'
eval_idx = eval_cell_idx & eval_pert_idx

In [4]:
# here we set up our embeddings for cfm
# this is just so everything lives in obsm for the for loop below
adata.obsm["X"] = adata.X

# this is an example of how we can embed something using just the train idxs
# and then run fm on that embedding
# embedder = PCA(n_components=30).fit(adata.X[(control_idx | pert_idx) & ~eval_idx])
# adata.obsm["X_pca"] = embedder.transform(adata.X)

In [5]:
# here we set up the perts
import pandas as pd
perts = pd.get_dummies(adata.obs['sm_name']).values.astype(float)
pert_ids = perts.argmax(axis=1)
# this is the "identity featurization"; we can swap this matrix for
# any latent representation of perturbations we want but this is 
# a non-parametric featurization for right now
pert_mat = np.eye(pert_ids.max() + 1).astype('float32')

In [6]:
cell_types = pd.get_dummies(adata.obs.cell_type).values.argmax(axis=1)

In [7]:
X = adata.obsm['X']

# set train and eval split
control_train = X[control_idx & ~eval_idx]
pert_train = X[pert_idx & ~eval_idx]
pert_ids_train =  pert_ids[pert_idx & ~eval_idx]
control_cell_types = cell_types[control_idx & ~eval_idx]
pert_cell_types = cell_types[pert_idx & ~eval_idx]

control_eval = X[control_idx & eval_cell_idx]
pert_eval = X[eval_idx]
pert_ids_eval = pert_ids[eval_idx]

batch_size = 32

In [11]:
from DSBM_Gaussian import * 

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="conf", config_name="gaussian.yaml")


In [38]:
from argparse import Namespace
cfg = Namespace(net_name="mlp_small", num_steps=20, sigma=1, inner_iters=10000, outer_iters=40,
               activation_fn=torch.nn.SiLU, model_name="dsb", first_coupling="ref",  fb_sequence=['b', 'f'])

In [39]:
x0 = torch.tensor(control_train)
x1 = torch.tensor(pert_train)
x1 = x0[:x1.shape[0]]
dim = x0.shape[1]
x_pairs = torch.stack([x0, x1], dim=1).to(device)

In [40]:
net_split = cfg.net_name.split("_")
if net_split[0] == "mlp":
    if net_split[1] == "small":
        net_fn = partial(ScoreNetwork, input_dim=dim+1, layer_widths=[128, 128, dim], activation_fn=cfg.activation_fn())    # hydra.utils.get_method(cfg.activation_fn))    # 
    else:
        net_fn = partial(ScoreNetwork, input_dim=dim+1, layer_widths=[256, 256, dim], activation_fn=cfg.activation_fn())    # hydra.utils.get_method(cfg.activation_fn))    # 
else:
    raise NotImplementedError

num_steps = cfg.num_steps
sigma = cfg.sigma
inner_iters = cfg.inner_iters
outer_iters = cfg.outer_iters

if cfg.model_name == "dsb":
    model = DSB(net_fwd=net_fn().to(device), 
                            net_bwd=net_fn().to(device), 
                            num_steps=num_steps, sig=sigma)
    train_fn = train_dsb_ipf
    print(f"Number of parameters: <{sum(p.numel() for p in model.net_fwd.parameters() if p.requires_grad)}>")
elif cfg.model_name == "dsbm":
    model = DSBM(net_fwd=net_fn().to(device), 
                                net_bwd=net_fn().to(device), 
                                num_steps=num_steps, sig=sigma, first_coupling=cfg.first_coupling)
    train_fn = train_dsbm
    print(f"Number of parameters: <{sum(p.numel() for p in model.net_fwd.parameters() if p.requires_grad)}>")
elif cfg.model_name == "sbcfm":
    model = SBCFM(net=net_fn().to(device), 
                                num_steps=num_steps, sig=sigma)
    train_fn = train_flow_model
    print(f"Number of parameters: <{sum(p.numel() for p in model.net.parameters() if p.requires_grad)}>")
elif cfg.model_name == "rectifiedflow":
    model = RectifiedFlow(net=net_fn().to(device), 
                                                num_steps=num_steps, sig=None)
    train_fn = train_flow_model
    print(f"Number of parameters: <{sum(p.numel() for p in model.net.parameters() if p.requires_grad)}>")
else:
    raise ValueError("Wrong model_name!")


# Training loop
# first_it = True
model_list = []
it = 1

# assert outer_iters % len(cfg.fb_sequence) == 0
while it <= outer_iters:
    for fb in cfg.fb_sequence:
        print(f"Iteration {it}/{outer_iters} {fb}")
        first_it = (it == 1)
        if first_it:
            prev_model = None
        else:
            prev_model = model_list[-1]["model"].eval()
        model, loss_curve = train_fn(model, x_pairs, batch_size, inner_iters, prev_model=prev_model, fb=fb, first_it=first_it)
        model_list.append({'fb': fb, 'model': copy.deepcopy(model).eval()})
    
        if False: # hasattr(model, "sample_sde"):
            draw_plot(partial(model.sample_sde, zstart=x_test_dict[fb], fb=fb, first_it=first_it), z0=x_test_dict['f'], z1=x_test_dict['b'])
            plt.savefig(f"{it}-{fb}.png")
            plt.close()

            # Evaluation
            optimal_result_dict = {'mean': -a, 'var': 1, 'cov': (np.sqrt(5) - 1) / 2}
            result_list = {k: [] for k in optimal_result_dict.keys()}
            for i in range(it):
                traj = model_list[i]['model'].sample_sde(zstart=x1_test, fb='b')
                result_list['mean'].append(traj[-1].mean(0).mean(0).item())
                result_list['var'].append(traj[-1].var(0).mean(0).item())
                result_list['cov'].append(torch.cov(torch.cat([traj[0], traj[-1]], dim=1).T)[dim:, :dim].diag().mean(0).item())
            for i, k in enumerate(result_list.keys()):
                plt.plot(result_list[k], label=f"{cfg.model_name}-{cfg.net_name}")
                plt.plot(np.arange(outer_iters), optimal_result_dict[k] * np.ones(outer_iters), label="optimal", linestyle="--")
                plt.title(k.capitalize())
                if i == 0:
                    plt.legend()
                plt.savefig(f"convergence_{k}.png")
                plt.close()
            
            result_list_100 = {k: [] for k in optimal_result_dict.keys()}
            for i in range(it):
                traj_100 = model_list[i]['model'].sample_sde(zstart=x1_test, fb='b', N=100)
                result_list_100['mean'].append(traj_100[-1].mean(0).mean(0).item())
                result_list_100['var'].append(traj_100[-1].var(0).mean(0).item())
                result_list_100['cov'].append(torch.cov(torch.cat([traj_100[0], traj_100[-1]], dim=1).T)[dim:, :dim].diag().mean(0).item())
        
        if False: # hasattr(model, "sample_ode"):
            draw_plot(partial(model.sample_ode, zstart=x_test_dict[fb], fb=fb, first_it=first_it), z0=x_test_dict['f'], z1=x_test_dict['b'])
            plt.savefig(f"{it}-{fb}-ode.png")
            plt.close()

            # Evaluation
            optimal_result_dict_ode = {'mean': -a, 'var': 1}
            result_list_ode = {k: [] for k in optimal_result_dict_ode.keys()}
            for i in range(it):
                traj_ode = model_list[i]['model'].sample_ode(zstart=x1_test, fb='b')
                result_list_ode['mean'].append(traj_ode[-1].mean(0).mean(0).item())
                result_list_ode['var'].append(traj_ode[-1].var(0).mean(0).item())
            for i, k in enumerate(result_list_ode.keys()):
                plt.plot(result_list_ode[k], label=f"{cfg.model_name}-{cfg.net_name}-ode")
                plt.plot(np.arange(outer_iters), optimal_result_dict_ode[k] * np.ones(outer_iters), label="optimal", linestyle="--")
                plt.title(k.capitalize())
                if i == 0:
                    plt.legend()
                plt.savefig(f"convergence_{k}-ode.png")
                plt.close()
            
            result_list_ode_100 = {k: [] for k in optimal_result_dict_ode.keys()}
            for i in range(it):
                traj_ode_100 = model_list[i]['model'].sample_ode(zstart=x1_test, fb='b', N=100)
                result_list_ode_100['mean'].append(traj_ode_100[-1].mean(0).mean(0).item())
                result_list_ode_100['var'].append(traj_ode_100[-1].var(0).mean(0).item())

        # first_it = False
        it += 1

        if it > outer_iters:
            break

torch.save([{'fb': m['fb'], 'model': m['model'].state_dict()} for m in model_list], "model_list.pt")

if hasattr(model, "sample_sde"):
    df_result = pd.DataFrame(result_list)
    df_result_100 = pd.DataFrame(result_list_100)
    df_result.to_csv('df_result.csv')
    df_result.to_pickle('df_result.pkl')
    df_result_100.to_csv('df_result_100.csv')
    df_result_100.to_pickle('df_result_100.pkl')

    # Trajectory
    np.save("traj.npy", torch.stack(traj, dim=1).detach().cpu().numpy())
    np.save("traj_100.npy", torch.stack(traj_100, dim=1).detach().cpu().numpy())

if hasattr(model, "sample_ode"):
    df_result_ode = pd.DataFrame(result_list_ode)
    df_result_ode_100 = pd.DataFrame(result_list_ode_100)
    df_result_ode.to_csv('df_result_ode.csv')
    df_result_ode.to_pickle('df_result_ode.pkl')
    df_result_ode_100.to_csv('df_result_ode_100.csv')
    df_result_ode_100.to_pickle('df_result_ode_100.pkl')

    # Trajectory
    np.save("traj_ode.npy", torch.stack(traj_ode, dim=1).detach().cpu().numpy())
    np.save("traj_ode_100.npy", torch.stack(traj_ode_100, dim=1).detach().cpu().numpy())

return {}, {}

Number of parameters: <530768>
Iteration 1/40 b


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [25:07<00:00,  6.63it/s]


Iteration 2/40 f


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [50:35<00:00,  3.29it/s]


Iteration 3/40 b


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [45:08<00:00,  3.69it/s]


Iteration 4/40 f


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [41:11<00:00,  4.05it/s]


Iteration 5/40 b


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [39:53<00:00,  4.18it/s]


Iteration 6/40 f


100%|█████████████████████████| 10000/10000 [38:55<00:00,  4.28it/s]


Iteration 7/40 b


100%|█████████████████████████| 10000/10000 [38:54<00:00,  4.28it/s]


Iteration 8/40 f


100%|█████████████████████████| 10000/10000 [38:57<00:00,  4.28it/s]


Iteration 9/40 b


100%|█████████████████████████| 10000/10000 [39:06<00:00,  4.26it/s]


Iteration 10/40 f


100%|█████████████████████████| 10000/10000 [39:17<00:00,  4.24it/s]


Iteration 11/40 b


100%|█████████████████████████| 10000/10000 [36:38<00:00,  4.55it/s]


Iteration 12/40 f


 13%|███▍                      | 1331/10000 [04:08<26:55,  5.37it/s]


ValueError: Loss is nan