In [None]:
import warnings
warnings.filterwarnings('ignore')
import dynamo as dyn
import anndata
import pandas as pd
import numpy as np
import scipy.sparse
from anndata import AnnData
from scipy.sparse import csr_matrix
import preprocess_patched

In [None]:
# Load the dataset returned by dyn.sample_data.scEU_seq_rpe1() and cast the two
# cell-cycle annotation columns to float.
rpe1 = dyn.sample_data.scEU_seq_rpe1()
dyn.convert2float(rpe1, ['Cell_cycle_possition', 'Cell_cycle_relativePos'])

# Keep only cells of the 'Pulse' experiment type.
# Subsetting an AnnData yields a view; take an explicit copy so the edits below
# modify a real object instead of relying on anndata's implicit copy-on-write
# (whose warning is silenced by the filterwarnings call in the import cell).
rpe1_kinetics = rpe1[rpe1.obs.exp_type == 'Pulse', :].copy()

# Drop the cells whose time label is the non-numeric string 'dmso', then make the
# remaining labels numeric. Masking on the label avoids writing a numeric sentinel
# into a string column.
_time_label = rpe1_kinetics.obs['time'].astype(str)
rpe1_kinetics = rpe1_kinetics[(_time_label != 'dmso').to_numpy(), :].copy()
rpe1_kinetics.obs['time'] = rpe1_kinetics.obs['time'].astype(str).astype(float)

# Collapse the four raw layers into 'new' (ul + sl) and 'total' (sum of all four).
# NOTE(review): presumably the layers ending in 'l' hold the labeled counts — confirm.
rpe1_kinetics.layers['new'] = rpe1_kinetics.layers['ul'] + rpe1_kinetics.layers['sl']
rpe1_kinetics.layers['total'] = (
    rpe1_kinetics.layers['su']
    + rpe1_kinetics.layers['sl']
    + rpe1_kinetics.layers['uu']
    + rpe1_kinetics.layers['ul']
)

# The raw layers are no longer needed once the two aggregates exist.
for _raw_layer in ('uu', 'ul', 'su', 'sl'):
    del rpe1_kinetics.layers[_raw_layer]

In [None]:
rpe1_genes = ['UNG', 'PCNA', 'PLK1', 'HPRT1']

# Labeling time: cast to float and convert minutes to hours in one step.
# Caution: this overwrites obs['time'], so re-running the cell divides by 60 again.
rpe1_kinetics.obs['time'] = rpe1_kinetics.obs['time'].astype('float') / 60

# dyn.tl.recipe_kin_data (called with the same adata / keep_filtered_genes /
# keep_raw_layers / del_2nd_moments / tkey arguments) is deliberately NOT used here;
# the patched copy from preprocess_patched is called instead, with two extra
# arguments: an empty `kwargs` dict and a PCA size of 100 components.
preprocess_patched.recipe_kin_data(
    adata=rpe1_kinetics,
    tkey='time',
    keep_filtered_genes=True,
    keep_raw_layers=True,
    del_2nd_moments=False,
    kwargs={},
    pca_kwargs={'n_pca_components': 100},
)

In [None]:
# Show the shape of the PCA embedding stored in obsm["X_pca"].
# NOTE(review): the recipe above was asked for n_pca_components=100, so the second
# entry is presumably 100 — confirm against the printed output.
print(rpe1_kinetics.obsm["X_pca"].shape)

In [None]:
def streamline(adata, basis='RFP_GFP', color=None, vkey='velocity_T', ekey='M_t', method='cosine'):
    """Project velocities onto an embedding and draw a streamline plot.

    Runs, in order: ``dyn.tl.reduceDimension`` (UMAP), ``dyn.tl.cell_velocities``
    (always with ``enforce=True``) and ``dyn.pl.streamline_plot``. All three operate
    on ``adata`` directly.

    Parameters
    ----------
    adata
        The AnnData object to process; it is passed to every dynamo call.
    basis
        Embedding name handed to both ``cell_velocities`` and ``streamline_plot``.
        Defaults to ``'RFP_GFP'``.
    color
        List of keys used to colour the plot. ``None`` (default) selects
        ``['Cell_cycle_possition', 'Cell_cycle_relativePos']``.
    vkey, ekey, method
        Forwarded unchanged to ``dyn.tl.cell_velocities``.

    Returns
    -------
    The same ``adata`` object that was passed in.
    """
    if color is None:
        # Built per call so the default list is never shared between invocations.
        color = ['Cell_cycle_possition', 'Cell_cycle_relativePos']
    dyn.tl.reduceDimension(adata, reduction_method='umap')
    dyn.tl.cell_velocities(adata, enforce=True, vkey=vkey, ekey=ekey, basis=basis, method=method)
    dyn.pl.streamline_plot(adata, color=color, basis=basis)
    return adata
# Register a two-column embedding named 'RFP_GFP', built from the two log10-corrected
# fluorescence columns of obs, then run the velocity projection / plot on it.
_fluorescence_columns = ['RFP_log10_corrected', 'GFP_log10_corrected']
rpe1_kinetics.obsm['X_RFP_GFP'] = rpe1_kinetics.obs[_fluorescence_columns].to_numpy().astype('float')
streamline(rpe1_kinetics)

In [None]:
# Print the shape of obsm["X_pca"] again, now that streamline() has run (it calls
# dyn.tl.reduceDimension with reduction_method='umap').
# NOTE(review): presumably a check that X_pca was left untouched — confirm intent.
print(rpe1_kinetics.obsm["X_pca"].shape)

In [None]:
# Streamline plot in the 'pca' basis, coloured by 'cell_cycle_phase'.
# NOTE(review): no visible cell above calls cell_velocities with basis='pca' (the
# following cell does) — confirm this runs on a fresh kernel, or move it below that call.
# NOTE(review): 'cell_cycle_phase' is not created in the visible cells; presumably it
# comes from the dataset or the preprocessing recipe — verify.
dyn.pl.streamline_plot(rpe1_kinetics, color=['cell_cycle_phase'], basis='pca')

In [None]:
# Project the 'velocity_T' velocities (expression key 'M_t') onto the PCA basis.
# enforce=True is passed as in the RFP_GFP projection; neg_cells_trick is switched on here.
_pca_velocity_options = dict(
    basis="pca",
    vkey="velocity_T",
    ekey="M_t",
    method="cosine",
    enforce=True,
    neg_cells_trick=True,
)
dyn.tl.cell_velocities(rpe1_kinetics, **_pca_velocity_options);

In [None]:
# Fit a vector field on each embedding in turn (RFP_GFP first, then pca), both with M=50.
# The later tensor-export cell reads the "X_pca_SparseVFC" / "velocity_pca_SparseVFC"
# entries of obsm, presumably written by the pca fit — confirm.
for _vf_basis in ('RFP_GFP', 'pca'):
    dyn.vf.VectorField(rpe1_kinetics, basis=_vf_basis, M=50)

In [None]:
# Persist the processed AnnData object to "rpe1_kinetics.h5ad"; the path is relative,
# so the file lands in the current working directory.
rpe1_kinetics.write_h5ad("rpe1_kinetics.h5ad")

In [None]:
import torch

# Alternative input (disabled): the 2-D RFP/GFP embedding instead of PCA.
# data = {"x" : torch.tensor(rpe1_kinetics.obsm["X_RFP_GFP_SparseVFC"]).float(),
#         "v" : torch.tensor(rpe1_kinetics.obsm["velocity_RFP_GFP_SparseVFC"]).float()}

# Number of leading PCA columns to export.
dim = 30
# dim = 100

_x_pca_vfc = rpe1_kinetics.obsm["X_pca_SparseVFC"]
_v_pca_vfc = rpe1_kinetics.obsm["velocity_pca_SparseVFC"]

# A plain slice would silently truncate when dim exceeds the available columns;
# fail loudly instead, as indexing with range(dim) used to.
_n_available = min(_x_pca_vfc.shape[1], _v_pca_vfc.shape[1])
if dim > _n_available:
    raise IndexError(f"dim={dim} exceeds the {_n_available} available columns")

# Positions ("x") and velocities ("v") restricted to the first `dim` columns,
# stored as float32 tensors.
data = {
    "x": torch.tensor(_x_pca_vfc[:, :dim]).float(),
    "v": torch.tensor(_v_pca_vfc[:, :dim]).float(),
}

In [None]:
import sklearn as sk
from sklearn import linear_model

# Positions and velocities exported above (torch tensors).
_X = data['x']
_v = data['v']

# Fit linear field: multi-output ridge regression v ~ A @ x + b.
# sklearn is given explicit NumPy arrays rather than torch tensors.
lr = linear_model.RidgeCV()
lr.fit(_X.numpy(), _v.numpy())
A, b = lr.coef_, lr.intercept_

# Fixed point of the fitted field: regularised least-squares solution of A @ mu = -b,
# i.e. mu = -(A^T A + eps * I)^+ A^T b, so that A @ (x - mu) ~ A @ x + b.
_fixed_point_eps = 1e-3
_gram = A.T @ A + _fixed_point_eps * np.eye(_X.shape[1])
mu = -torch.tensor(np.linalg.pinv(_gram) @ A.T @ b, dtype = torch.float32)
A = torch.tensor(A, dtype = torch.float32)

# Velocities predicted by the linear field. Computed entirely in torch: the previous
# expression mixed tensors with NumPy arrays and left the result type to
# operator-dispatch fallbacks.
_v_fit = (_X - mu) @ A.T

# Discretise the relative cell-cycle position into equal-width bins, indexed 0..n-1.
# The last edge is dropped so that the maximum value falls into the final bin
# instead of getting an index of its own.
_n_time_bins = 5
t = rpe1_kinetics.obs.Cell_cycle_relativePos
_left_edges = np.histogram_bin_edges(t, _n_time_bins)[:-1]
t_bin = np.digitize(t, _left_edges) - 1

In [None]:
import matplotlib.pyplot as plt

# Side-by-side quiver plots over the first two data dimensions: the input velocities
# (left) versus the velocities of the fitted linear field (right). Points are coloured
# by t_bin and the red cross marks mu.
fig, axes = plt.subplots(1, 2, figsize = (10, 5))
_panels = (
    ("_v (input velocities)", _v),
    ("_v_fit (linear fit)", _v_fit),
)
for _ax, (_title, _velocity) in zip(axes, _panels):
    _ax.scatter(_X[:, 0], _X[:, 1], alpha = 0.25, c = t_bin)
    _ax.quiver(_X[:, 0], _X[:, 1], _velocity[:, 0], _velocity[:, 1])
    _ax.scatter(mu[0:1], mu[1:2], marker = 'x', color = 'red')
    _ax.set(title = _title, xlabel = "dimension 0", ylabel = "dimension 1")
plt.show()

In [None]:
# Attach the fitted matrix A, the point mu and the per-cell bin index to the export
# dict (this extends `data` in place), then serialise it with torch.save.
# The file name encodes the number of exported dimensions.
data.update(A=A, mu=mu, t_idx=t_bin)
torch.save(data, f"data_cellcycle_pca_{dim}.pkl")