In [None]:
import os
import sys

import anndata
import numpy as np
import scvelo as scv
import torch

sys.path.append('../')
import velovae as vv

In [None]:
# Load the preprocessed developing-brain (braindev) dataset.
# Other inputs used at some point, kept for reference:
#   '/nfs/turbo/umms-welchjd/yichen/data/scRNA/3423-MV-2_adata_postpro.h5ad'
#   '/scratch/blaauw_root/blaauw1/gyichen/output.h5ad'   (was read into `adata2`)
filename = '/nfs/turbo/umms-welchjd/yichen/data/scRNA/braindev_pp.h5ad'
adata = anndata.read_h5ad(filename=filename)

In [None]:
# Run velovae's preprocessing on `adata` in place of any separate scVelo steps.
# Ngene is passed as the second positional argument (presumably the number of genes
# to keep — TODO confirm against vv.preprocess).
Ngene = 1000
vv.preprocess(adata, Ngene)

In [None]:
# Where this run writes its figures, model checkpoints and the annotated AnnData.
figure_path, model_path, data_path = (
    '../figures/Braindev/Rho',
    '../checkpoints/Braindev/Rho',
    '/scratch/blaauw_root/blaauw1/gyichen/',
)
# Leftover alternatives / inspection snippets from other datasets:
#data_path = '../data/Dentategyrus'
#adata.obs['clusters'] = adata.obs['leiden'].to_numpy()
#adata.var.keys()
#np.unique(adata.obs['phase'].to_numpy())

In [None]:
# Plot the existing UMAP embedding of `adata`; the `save` name is handed to scVelo
# (presumably written relative to scVelo's figure directory, not figure_path — confirm).
# NOTE(review): the file name says "blood" but the data loaded above is braindev —
# looks like a copy-paste leftover from another notebook.
scv.pl.umap(adata, save='blood_umap.png')

In [None]:
# Build the VAE on `adata`. Cz is passed as the third positional argument
# (presumably the latent-state dimension — TODO confirm against VanillaVAEpp;
# the meaning of the positional 20 should be confirmed too).
# hidden_size lists hidden-layer widths; tprior=None is presumably "no time prior" — confirm.
Cz = 5
model = vv.VanillaVAEpp(adata, 20, Cz, hidden_size=(500,250,250,500), tprior=None, device='gpu')

In [None]:
# Training hyper-parameters, passed to model.train(config=...) below.
# A literal is required (not dict(...)) because 'lambda' is a Python keyword.
config_vae = {
    'num_epochs': 500,
    'test_epoch': 50,
    'save_epoch': 50,
    'learning_rate': 2e-4,
    'learning_rate_ode': 2e-4,
    'lambda': 1e-3,
    'neg_slope': 0,
    'reg_t': 1.0,
    'reg_z': 1.0,
    'batch_size': 256,
}

In [None]:
def sampleGenes(adata, n):
    """
    Randomly draw `n` distinct highly-variable genes.

    Genes are ranked by adata.var['mean_counts'] (largest first), restricted to those
    flagged True in adata.var['highly_variable'], and then `n` of them are picked
    uniformly without replacement using NumPy's global random state.

    adata: object exposing `var` (with 'highly_variable' and 'mean_counts' columns)
           and `var_names`
    n: number of genes to draw; must not exceed the number of highly-variable genes
    Returns an array of gene names.
    """
    # Descending order of total counts (reverse of an ascending argsort).
    order = np.argsort(adata.var['mean_counts'].to_numpy())[::-1]
    hv_sorted = adata.var['highly_variable'].to_numpy()[order]
    candidates = adata.var_names.to_numpy()[order][hv_sorted]
    picked = np.random.choice(np.sum(hv_sorted), n, replace=False)
    return candidates[picked]

In [None]:
# Genes handed to model.train as `gene_plot` (presumably the genes whose fits are
# plotted during training — confirm against velovae).
gene_plot = ['Auts2', 'Dync1i1', 'Gm3764', 'Mapt', 'Nfib', 'Rbfox1', 'Satb2', 'Slc6a13', 'Srrm4', 'Tcf4']
# Alternative gene panels (Dentate gyrus) and a random draw, kept for reference:
#gene_plot = ['Pcsk2','Dcdc2a','Gng12','Cpe','Smoc1','Tmem163','Ank', 'Ppp3ca']
#gene_plot = ['Ppp3ca','Ak5','Btbd9','Tmsb10','Hn1','Dlg2','Tcea1','Herc2']
#gene_plot = sampleGenes(adata, 8)
print(gene_plot)
# Train with the hyper-parameters in config_vae; training figures go under figure_path.
model.train(adata, config=config_vae, gene_plot=gene_plot, figure_path=figure_path)

In [None]:
# Persist the trained encoder/decoder weights under model_path, then write the annotated
# AnnData (presumably with the 'vanillapp_*' fields the analysis cells read — TODO confirm)
# to data_path/output_vanillapp.h5ad.
model.saveModel(model_path, 'encoder_vanillapp', 'decoder_vanillapp')
model.saveAnnData(adata, 'vanillapp', data_path, file_name='output_vanillapp.h5ad')

In [None]:
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Fixed 30-colour palette for per-cluster scatter plots (indexed by cluster position).
colors = [
    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'lime', 'grey',
    'olive', 'cyan', 'pink', 'gold', 'teal', 'steelblue', 'salmon',
    'magenta', 'rosybrown', 'darkorange', 'yellow', 'greenyellow',
    'darkseagreen', 'yellowgreen', 'palegreen',
    'hotpink', 'navajowhite', 'aqua', 'navy', 'saddlebrown', 'maroon', 'black',
]

In [None]:
# Reload a saved VanillaVAEpp run for analysis.
# NOTE(review): this is the Dentate gyrus output, not the braindev run trained above,
# and it replaces `adata`.
filename = '../data/Dentategyrus/output_vanillapp.h5ad'
adata = anndata.read_h5ad(filename=filename)

In [None]:
# Rebuild the model with the same constructor arguments as the training cell so the
# saved encoder/decoder weights fit.
Cz = 5
model = vv.VanillaVAEpp(adata, 20, Cz, hidden_size=(500,250,250,500), tprior=None, device='gpu')

In [None]:
# Restore trained encoder/decoder weights, mapping tensors onto the model's device.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted
# sources (recent PyTorch supports weights_only=True for state dicts).
checkpoint1 = '../checkpoints/Dentategyrus/VanillaVAEpp/encoder_vanillapp.pt'
model.encoder.load_state_dict(torch.load(checkpoint1,map_location=model.device))
checkpoint2 = '../checkpoints/Dentategyrus/VanillaVAEpp/decoder_vanillapp.pt'
model.decoder.load_state_dict(torch.load(checkpoint2,map_location=model.device))

In [None]:
# Decode rho from the saved latent z: the sigmoid of the decoder's fc_out1 head applied
# to decoder.net(z), so every entry lies in (0, 1). rho later scales alpha in rnaVelocity.
z = adata.obsm['vanillapp_z']
z_ts = torch.tensor(z).to(model.device)
# Inference only: no autograd graph is needed for the full set of cells.
# torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
with torch.no_grad():
    rho = torch.sigmoid(model.decoder.fc_out1(model.decoder.net(z_ts)))

In [None]:
# Move rho to host memory as a NumPy array. Guarded so re-running the cell is harmless
# (a second run would otherwise fail: NumPy arrays have no .detach()).
if torch.is_tensor(rho):
    rho = rho.detach().cpu().numpy()
rho

In [None]:
# Project rho onto its first three principal components for the 3-D scatter plot.
pca = PCA(n_components=3)
rho_pca = pca.fit_transform(rho)
rho_pca

In [None]:
# 3-D scatter of the first three principal components of rho, coloured by cluster,
# saved to figure_path/rho.png.
cell_labels = adata.obs['clusters'].to_numpy()
cell_types = np.unique(cell_labels)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(projection='3d')
ax.view_init(30, 30)
for i, x in enumerate(cell_types):
    mask = cell_labels == x  # computed once per cluster
    # Wrap around the palette so more clusters than colours doesn't raise IndexError.
    ax.scatter(rho_pca[mask,0], rho_pca[mask,1], rho_pca[mask,2], label=x, color=colors[i % len(colors)])

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
ax.legend(bbox_to_anchor=(-0.15,1.0), loc='upper right')
# Save before showing (safe on every backend) and make sure the directory exists.
os.makedirs(figure_path, exist_ok=True)
fig.savefig(figure_path+'/rho.png')
plt.show()

In [None]:
import umap

# 2-D UMAP of the latent z, coloured by cluster, saved to figure_path/z.png.
# NOTE(review): no random_state is set, so the embedding differs between runs.
umap_obj = umap.UMAP(n_neighbors=30, n_components=2, min_dist=0.25)
z_umap = umap_obj.fit_transform(z)
cell_labels = adata.obs['clusters'].to_numpy()
cell_types = np.unique(cell_labels)
fig = plt.figure(figsize=(10,10))
for i, x in enumerate(cell_types):
    mask = cell_labels == x  # computed once per cluster
    # Wrap around the palette so more clusters than colours doesn't raise IndexError.
    plt.scatter(z_umap[mask,0], z_umap[mask,1], label=x, color=colors[i % len(colors)])
plt.legend(bbox_to_anchor=(-0.15,1.0), loc='upper right')
# Save before showing (safe on every backend) and make sure the directory exists.
os.makedirs(figure_path, exist_ok=True)
fig.savefig(figure_path+'/z.png')
plt.show()

In [None]:
# 2-D UMAP of rho, coloured by cluster (reuses cell_labels / cell_types from the z-UMAP cell).
# NOTE(review): no random_state is set, so the embedding differs between runs.
umap_obj = umap.UMAP(n_neighbors=30, n_components=2, min_dist=0.25)
rho_umap = umap_obj.fit_transform(rho)
fig = plt.figure(figsize=(10,10))
for i, x in enumerate(cell_types):
    mask = cell_labels == x  # computed once per cluster
    # Wrap around the palette so more clusters than colours doesn't raise IndexError.
    plt.scatter(rho_umap[mask,0], rho_umap[mask,1], label=x, color=colors[i % len(colors)])
plt.legend(bbox_to_anchor=(-0.15,1.0), loc='upper right')
plt.show()

In [None]:
# Store the 2-D embeddings in obsm under 'X_<name>' (presumably so they can be used as
# plotting bases, e.g. basis='z' / basis='rho', in scVelo/scanpy — confirm).
adata.obsm['X_z'] = z_umap
#adata2.obsm['X_z'] = z_umap
adata.obsm['X_rho'] = rho_umap
#adata2.obsm['X_rho'] = rho_umap

In [None]:
# Per-cluster mean/std of rho for one gene, restricted to cells at or past that gene's
# switch-on time (the same `t >= ton` convention rnaVelocity uses).
gene_name = 'Cpe'
matches = np.flatnonzero(adata.var_names == gene_name)
if matches.size == 0:
    raise KeyError(f"Gene {gene_name!r} not found in adata.var_names")
# Scalar index: np.where returned a tuple, which made rho[:, gidx] a (N, 1, 1) array.
gidx = matches[0]
rho_g = rho[:, gidx]
ton = adata.var['vanillapp_ton'].to_numpy()
t = adata.obs['vanillapp_time'].to_numpy()
print(ton[gidx])
for x in np.flip(cell_types):
    # One mask for both the emptiness check and the statistics
    # (previously the check used '>' while the statistics used '>=').
    sel = (cell_labels == x) & (t >= ton[gidx])
    if np.any(sel):
        print(x,': ', rho_g[sel].mean(), rho_g[sel].std())

In [None]:
def predSteadyNumpy(ts,alpha,beta,gamma):
    """
    (Numpy Version)
    Predict the state (u0, s0) reached at the end of the induction phase.

    Starting from u = s = 0 with constant transcription rate alpha, this evaluates the
    closed-form solution of du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s at time ts.

    ts: [G] switching time, when the kinetics enters the repression phase
        (elapsed since switch-on; odeNumpy passes ts - to)
    alpha, beta, gamma: [G] generation, splicing and degradation rates
    Returns (u0, s0), broadcast from the input shapes.
    """
    eps = 1e-6
    # Near beta == gamma the generic term alpha/(gamma-beta)*(expg-expb) is ill-conditioned;
    # its analytic limit, -alpha*ts*exp(-gamma*ts), is substituted on those genes.
    # (The sign was previously '+', which does not satisfy the ODE.)
    unstability = np.abs(beta-gamma) < 1e-3

    ts_ = ts.squeeze()
    expb, expg = np.exp(-beta*ts_), np.exp(-gamma*ts_)
    u0 = alpha/(beta+eps)*(1.0-expb)
    s0 = (alpha/(gamma+eps)*(1.0-expg)
          + alpha/(gamma-beta+eps)*(expg-expb)*(1-unstability)
          - alpha*ts_*expg*unstability)
    return u0,s0
def odeNumpy(t,alpha,beta,gamma,to,ts,scaling=None):
    """
    (Numpy Version)
    ODE Solution

    Closed-form solution of du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s with
    u = s = 0 before the switch-on time `to`, constant transcription until the
    switch-off time `ts`, and free decay (no transcription) afterwards.

    t: [B x 1] cell time
    alpha, beta, gamma: [G] generation, splicing and degradation rates
    to, ts: [G] switch-on and -off time
    scaling: optional factor multiplied into the returned unspliced values
    Returns (uhat, shat) broadcast over cells and genes.
    """
    # Near beta == gamma the 1/(gamma-beta) terms are ill-conditioned; their analytic
    # limits are substituted on those genes.
    unstability = (np.abs(beta - gamma) < 1e-3)
    eps = 1e-6

    o = (t<=ts).astype(int)  # 1 while still inducing, 0 once repressing
    #Induction
    tau_on = np.clip(t-to,a_min=0,a_max=None)
    expb_on, expg_on = np.exp(-beta*tau_on), np.exp(-gamma*tau_on)
    uhat_on = alpha/(beta+eps)*(1.0-expb_on)
    # Limit of alpha/(gamma-beta)*(expg-expb) as beta -> gamma is -alpha*tau*exp(-gamma*tau)
    # (previously '+alpha*tau', wrong sign and missing the exponential).
    shat_on = (alpha/(gamma+eps)*(1.0-expg_on)
               + alpha/(gamma-beta+eps)*(expg_on-expb_on)*(1-unstability)
               - alpha*tau_on*expg_on*unstability)

    #Repression
    u0_,s0_ = predSteadyNumpy(ts-to,alpha,beta,gamma) #[G]
    if(ts.ndim==2 and to.ndim==2):
        u0_ = u0_.reshape(-1,1)
        s0_ = s0_.reshape(-1,1)
    tau_off = np.clip(t-ts,a_min=0,a_max=None)
    expb_off, expg_off = np.exp(-beta*tau_off), np.exp(-gamma*tau_off)
    uhat_off = u0_*expb_off
    # Limit of -beta*u0/(gamma-beta)*(expg-expb) as beta -> gamma is +beta*u0*tau*exp(-gamma*tau)
    # (previously dropped entirely, leaving only s0*expg on those genes).
    shat_off = (s0_*expg_off
                + (-beta*u0_)/(gamma-beta+eps)*(expg_off-expb_off)*(1-unstability)
                + beta*u0_*tau_off*expg_off*unstability)

    uhat, shat = (uhat_on*o + uhat_off*(1-o)),(shat_on*o + shat_off*(1-o))
    if(scaling is not None):
        uhat *= scaling
    return uhat, shat
def rnaVelocity(adata, key, rho, use_raw=False, use_scv_genes=False):
    """
    Compute the velocity based on:
    ds/dt = beta * u - gamma * s

    adata: AnnData-like object with fitted parameters under
           var[f"{key}_alpha|_beta|_gamma|_ton|_t_"] and obs[f"{key}_time"]
    key: prefix of the fitted parameters, e.g. "vanillapp"
    rho: per-cell factor multiplied into alpha before solving the ODE
         (ignored when use_raw=True)
    use_raw: use the 'Mu'/'Ms' layers instead of the ODE predictions for u and s
    use_scv_genes: set the velocity to NaN for genes whose var['fit_scaling'] is NaN

    Side effects: writes adata.layers[f"{key}_velocity"] and, when use_raw is False,
    adata.layers["Uhat"] / adata.layers["Shat"].
    Returns (V, U, S).
    """
    alpha = adata.var[f"{key}_alpha"].to_numpy()
    beta = adata.var[f"{key}_beta"].to_numpy()
    gamma = adata.var[f"{key}_gamma"].to_numpy()
    t = adata.obs[f"{key}_time"].to_numpy()
    ton = adata.var[f"{key}_ton"].to_numpy()
    toff = adata.var[f"{key}_t_"].to_numpy()
    if(use_raw):
        U, S = adata.layers['Mu'], adata.layers['Ms']
    else:
        # Scaling is not applied to U here, so the stored Uhat is unscaled.
        U, S = odeNumpy(t.reshape(-1,1),alpha * rho,beta,gamma,ton,toff, None)
        adata.layers["Uhat"] = U
        adata.layers["Shat"] = S

    # Velocity is zero for cells before each gene's switch-on time.
    V = (beta * U - gamma * S)*(t.reshape(-1,1) >= ton)
    if(use_scv_genes):
        # Mask before storing so the stored layer and the returned array agree,
        # whether or not the container copies on assignment.
        gene_mask = np.isnan(adata.var['fit_scaling'].to_numpy())
        V[:, gene_mask] = np.nan
    adata.layers[f"{key}_velocity"] = V
    return V, U, S

In [None]:
# Velocity from the ODE prediction (alpha scaled by rho), then scVelo's velocity graph,
# UMAP projection and stream plot.
key='vanillapp'
# NOTE(review): the AnnData analysed here was loaded from ../data/Dentategyrus;
# the "Blood" figure directory / file prefix look like leftovers from another notebook.
figure_path = '../figures/Blood'
# Use `key` everywhere so changing it keeps the computed layer and vkey in sync.
V, U, S = rnaVelocity(adata, key, rho, use_raw=False, use_scv_genes=False)
scv.tl.velocity_graph(adata, vkey=f'{key}_velocity', basis='umap', n_jobs=2)
scv.tl.velocity_embedding(adata, vkey=f'{key}_velocity', basis='umap')
scv.pl.velocity_embedding_stream(adata, vkey=f'{key}_velocity', basis='umap', figsize=(8,6), save=figure_path+f'/blood_{key}velz.png')