In [None]:
import velvet as vt

In [None]:
# general packages
import numpy as np
import pandas as pd
import torch
from scipy.sparse import issparse

# velocity packages
import scanpy as sc
import scvelo as scv
import anndata as ann

# plotting packages
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm, trange
from IPython.display import clear_output

# color palette object
from colors import colorpalette as colpal

In [None]:
import pickle 
import colorsys
import random

def generate_blue_shades(n, hue_range=(0.3, 0.9), saturation_range=(0.5, 1.0), value_range=(0.5, 1.0), seed=None):
    """Return ``n`` random colours as ``#rrggbb`` hex strings.

    Each colour is drawn uniformly in HSV space and converted to RGB.

    Parameters
    ----------
    n : int
        Number of colours to generate; ``n <= 0`` yields an empty list.
    hue_range, saturation_range, value_range : tuple of float
        Inclusive ``(low, high)`` bounds, each on the 0-1 scale used by
        ``colorsys.hsv_to_rgb``. Note that the default hue range 0.3-0.9 runs
        from green through blue to magenta, so the output is not strictly blue.
    seed : int or None, optional
        When given, colours come from a private ``random.Random(seed)`` so the
        palette is reproducible and the global ``random`` state is untouched.
        When ``None`` (default) the global ``random`` module is used, exactly
        as before.

    Returns
    -------
    list of str
        ``n`` lowercase hex colour strings.
    """
    # A private generator keeps seeded calls from disturbing the global RNG.
    rng = random if seed is None else random.Random(seed)

    blue_shades = []
    for _ in range(n):
        # Draw order (hue, saturation, value) is kept so that seeding the
        # global RNG reproduces the same colours as the previous version.
        hue = rng.uniform(*hue_range)
        saturation = rng.uniform(*saturation_range)
        value = rng.uniform(*value_range)
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
        # Channels are truncated (not rounded) onto the 0-255 integer scale.
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        blue_shades.append(f"#{r:02x}{g:02x}{b:02x}")
    return blue_shades

In [None]:
# Load the dataset, rebuild the Velvet model with the configuration below and
# restore its trained weights from disk.
neural = sc.read_h5ad('../data/neural_data_0606.h5ad')

state_path = '../models/neural_vf_0606_model.pt'

vt.md.Velvet.setup_anndata(neural, x_layer='total', n_layer='new', knn_layer='knn_index')

model = vt.md.Velvet(
    neural,
    n_latent = 50,
    linear_decoder = True,
    neighborhood_space="latent_space",
    biophysical_model = "full",
    gamma_mode = "learned",
    labelling_time = 2.0,
)

model.setup_model()

# map_location='cpu' lets a checkpoint that was saved from a GPU deserialize on
# any machine; load_state_dict then copies the values into the module's own
# parameters, wherever those live.
trained_state = torch.load(state_path, map_location='cpu')['model_state_dict']

model.module.load_state_dict(trained_state)

In [None]:
# Prefer the GPU but fall back to the CPU instead of failing on hosts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.module = model.module.to(device)

# Integer position of every cell, registered below as the index key.
model.adata.obs['index'] = np.arange(model.adata.shape[0])

vt.sm.VelvetSDE.setup_anndata(
    model,
    x_layer='total',
    index_key='index'
)

markov = vt.sb.MarkovProcess(
    model,
    n_neighbors=10,
    use_space='latent_space',
    use_spline=True,
    use_similarity=False
)

# The SDE uses the trained model's vector field as its prior.
sde = vt.sb.SDE(
    model.module.n_latent,
    prior_vectorfield=model.module.vf,
    noise_scalar=0.1,
    device=model.device
)

sde_model = vt.sm.VelvetSDE(
    model,
    sde,
    markov,
)


sde_state_path = '../models/neural_sde_0606_model.pt'

# Deserialize on the CPU so the checkpoint loads regardless of the device it
# was saved from; load_state_dict copies into the module's existing parameters.
trained_state_sde = torch.load(sde_state_path, map_location='cpu')['model_state_dict']

sde_model.module.load_state_dict(trained_state_sde)

# Demo plot

In [None]:
# Start the demo from cells annotated 'Early_Neural' or 'Neural'.
demo_mask = model.adata.obs.cell_annotation.isin(['Early_Neural','Neural']).values
initial_cells = model.adata[demo_mask]

# Subsample without replacement, capped at the number of available cells so the
# draw cannot fail when fewer than 1000 cells match.
# NOTE(review): the draw uses the unseeded global NumPy RNG, so the subset
# differs between runs.
n_demo_cells = min(1000, initial_cells.shape[0])
initial_cells = initial_cells[np.random.choice(initial_cells.shape[0], size=n_demo_cells, replace=False)]

# NOTE(review): n_steps, t_max and dt are all supplied; how simulate()
# reconciles them (100 steps vs. t_max=80 at dt=1.0) is not visible here - confirm.
trajectories, cell_ids = sde_model.simulate(
    initial_cells=initial_cells,
    n_samples_per_cell=1,
    n_steps = 100,
    t_max = 80,
    dt = 1.0,
    latent_key='X_z',
    n_chunks=10
)

In [None]:
%%time
# Group the demo trajectories into 3 clusters.
# NOTE(review): final_steps presumably restricts clustering to the last 60
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=3, 
    final_steps=60,
    n_iterations=500
)
# NumPy copy of the per-trajectory labels for use in the plotting loop.
cluster_labels = labels.cpu().numpy()

In [None]:
# Colour and legend label for each cluster id.
# NOTE(review): the id -> fate assignment is hard-coded; cluster numbering
# presumably can change between clustering runs - verify before reuse.
cmap = {2:'#87CEEB',1:'green',0:'#FF0000'}
clmap = {2:'FP Trajectories',1:'V3 Trajectories',0:'MN Trajectories'}

z = model.adata.obsm['X_z']
z = torch.tensor(z, device=model.device)

# PCA is fitted on the latent coordinates of all cells ...
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# ... and the same fitted projection is applied to every simulated trajectory.
t_pca = []
for traj in trajectories:
    t_pca.append(pca.transform(traj.detach().cpu().numpy()))

# Work on a copy so the PCA coordinates (stored as basis 'vae') do not touch model.adata.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}
fig = plt.figure(figsize=(15,8), dpi=300)
ax1, ax2 = fig.subplots(1,2)

# Background cells coloured by timepoint: components 1 and 2 on the left, 1 and 5 on the right.
scv.pl.scatter(copy, basis='vae', color='timepoint', alpha=0.3, palette=colpal.timepoint,
              ax=ax1, size=1000, show=False, components="1,2",
               legend_loc=False, title="PC1 v. PC2", fontsize=28)
scv.pl.scatter(copy, basis='vae', color='timepoint', alpha=0.3,
              ax=ax2, size=1000, show=False, components="1,5", palette=colpal.timepoint,
               legend_loc='right margin', title="PC1 v. PC5", fontsize=28)
# `groups` remembers which clusters already have a legend entry, so each
# cluster is labelled only once.
groups = []
for t, cl in zip(t_pca, cluster_labels):
    color = cmap[cl]
    if cl in groups:
        label=''
    else:
        label=clmap[cl]
        groups.append(cl)
    # Red cross marks the first point of the trajectory.
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax2.scatter(t[0,0],t[0,4], color='red', marker='x')
    # Zero-based columns 0/1 and 0/4 match the one-based components "1,2" and "1,5" above.
    ax1.plot(t[:,0],t[:,1], color=color, alpha=.5, linewidth=2, label=label)
    ax2.plot(t[:,0],t[:,4], color=color, alpha=.5, linewidth=2, label=label)

    
plt.suptitle("", fontsize=30, y=1.02)
# NOTE(review): plt.legend targets pyplot's current axes; which panel that is
# after the scvelo calls is not evident here - confirm.
plt.legend(loc=(.8,.55),fontsize=18)
plt.tight_layout()
plt.savefig('../figures/3.0.neural_trajectories.png', dpi=300)
plt.show()

In [None]:
# Same PCA embedding as the demo figure (uses the `copy` AnnData built there), coloured by timepoint.
scv.pl.scatter(copy, basis='vae', color='timepoint', legend_loc='right margin', palette=colpal.cmap_cat1)

# simulation

In [None]:
# Coarse two-way label: the listed annotations are progenitors, everything else is a neuron.
progenitor_annotations = ['p3','pMN','Neural','Early_Neural']
is_progenitor = model.adata.obs.cell_annotation.isin(progenitor_annotations)
model.adata.obs['cell_type'] = np.where(is_progenitor, 'Progenitor', 'Neuron').tolist()

In [None]:
# Starting population for all simulations below: progenitor cells from timepoint D4.
d4_mask = model.adata.obs.timepoint.isin(['D4']).values
d4_cells = model.adata[d4_mask]
initial_cells = d4_cells[d4_cells.obs.cell_type=='Progenitor']

### NOISE 0.1

In [None]:
# Simulate one trajectory per initial cell. No noise override is applied in
# this cell; the SDE was constructed with noise_scalar=0.1.
# NOTE(review): n_steps, t_max and dt are all supplied; how simulate()
# reconciles them is not visible here - confirm.
trajectories, cell_ids = sde_model.simulate(
    initial_cells=initial_cells,
    n_samples_per_cell=1,
    n_steps = 100,
    t_max = 100,
    dt = 1.0,
    latent_key='X_z',
    n_chunks=10
)

In [None]:
# Number of trajectory clusters; also drives the plotting and grouping cells below.
n_clusters = 3

# NOTE(review): final_steps presumably limits clustering to the last 50
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask trajectories when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = trajectories[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the noise-0.1 run.
# NOTE(review): hard-coded, presumably read off the per-cluster figures above;
# cluster numbering may differ after re-running the clustering - verify.
cluster_map = {0:'FP',1:'MN',2:'V3'}

In [None]:
# Several cluster ids may map to the same fate: collect each fate's pieces
# first, then stack them into one tensor per fate.
grouped = {fate: [] for fate in ('MN', 'V3', 'FP')}

for ci in range(n_clusters):
    fate = cluster_map[ci]
    grouped[fate].append(trajectories[labels==ci])

traj_dict = {}
for fate, pieces in grouped.items():
    traj_dict[fate] = torch.vstack(pieces)

with open('../data/F4_noisy0.1_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)

In [None]:
# Reload the saved trajectories and keep the 'MN' group for plotting.
with open('../data/F4_noisy0.1_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.1.noisy0.1_viz.png', bbox_inches='tight', transparent=True)
plt.show()

### NOISE 0.05

In [None]:
# Temporarily lower the SDE noise for this run. The finally clause puts the
# value back to 0.1 even when simulate() raises or is interrupted, so later
# cells never silently run with a leftover noise level.
sde_model.module.sde.noise_scalar = 0.05
try:
    trajectories, cell_ids = sde_model.simulate(
        initial_cells=initial_cells,
        n_samples_per_cell=1,
        n_steps = 100,
        t_max = 100,
        dt = 1.0,
        latent_key='X_z',
        n_chunks=10
    )
finally:
    sde_model.module.sde.noise_scalar = 0.1

In [None]:
# Number of trajectory clusters for the noise-0.05 run.
n_clusters = 3

# NOTE(review): final_steps presumably limits clustering to the last 50
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask trajectories when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = trajectories[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the noise-0.05 run (hard-coded for this run's clustering).
cluster_map = {0:'V3',1:'MN',2:'FP'}

# Collect each fate's pieces first, then stack them into one tensor per fate.
grouped = {fate: [] for fate in ('MN', 'V3', 'FP')}

for ci in range(n_clusters):
    fate = cluster_map[ci]
    grouped[fate].append(trajectories[labels==ci])

traj_dict = {}
for fate, pieces in grouped.items():
    traj_dict[fate] = torch.vstack(pieces)

with open('../data/F4_noisy0.05_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)

In [None]:
# Reload the saved trajectories and keep the 'MN' group for plotting.
with open('../data/F4_noisy0.05_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.noisy0.05_viz.png', bbox_inches='tight', transparent=True)
plt.show()

### NOISE 0.025

In [None]:
# Temporarily lower the SDE noise for this run. The finally clause puts the
# value back to 0.1 even when simulate() raises or is interrupted, so later
# cells never silently run with a leftover noise level.
sde_model.module.sde.noise_scalar = 0.025
try:
    trajectories, cell_ids = sde_model.simulate(
        initial_cells=initial_cells,
        n_samples_per_cell=1,
        n_steps = 100,
        t_max = 100,
        dt = 1.0,
        latent_key='X_z',
        n_chunks=10
    )
finally:
    sde_model.module.sde.noise_scalar = 0.1

In [None]:
# Number of trajectory clusters for the noise-0.025 run.
n_clusters = 3

# NOTE(review): final_steps presumably limits clustering to the last 50
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask trajectories when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = trajectories[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the noise-0.025 run (hard-coded for this run's clustering).
cluster_map = {0:'FP',1:'MN',2:'V3'}

# Collect each fate's pieces first, then stack them into one tensor per fate.
grouped = {fate: [] for fate in ('MN', 'V3', 'FP')}

for ci in range(n_clusters):
    fate = cluster_map[ci]
    grouped[fate].append(trajectories[labels==ci])

traj_dict = {}
for fate, pieces in grouped.items():
    traj_dict[fate] = torch.vstack(pieces)

with open('../data/F4_noisy0.025_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)

In [None]:
# Reload the saved trajectories and keep the 'MN' group for plotting.
with open('../data/F4_noisy0.025_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.noisy0.025_viz.png', bbox_inches='tight', transparent=True)
plt.show()

# CLEAN

In [None]:
# Switch the SDE noise off for this run. The finally clause puts the value
# back to 0.1 even when simulate() raises or is interrupted, so later cells
# never silently run with a leftover noise level.
sde_model.module.sde.noise_scalar = 0.0
try:
    trajectories, cell_ids = sde_model.simulate(
        initial_cells=initial_cells,
        n_samples_per_cell=1,
        n_steps = 100,
        t_max = 100,
        dt = 1.0,
        latent_key='X_z',
        n_chunks=10
    )
finally:
    sde_model.module.sde.noise_scalar = 0.1

In [None]:
# Number of trajectory clusters for the noise-free run.
n_clusters = 3

# NOTE(review): final_steps presumably limits clustering to the last 50
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask trajectories when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = trajectories[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the noise-free run (hard-coded for this run's clustering).
cluster_map = {0:'V3',1:'MN',2:'FP'}

# Collect each fate's pieces first, then stack them into one tensor per fate.
grouped = {fate: [] for fate in ('MN', 'V3', 'FP')}

for ci in range(n_clusters):
    fate = cluster_map[ci]
    grouped[fate].append(trajectories[labels==ci])

traj_dict = {}
for fate, pieces in grouped.items():
    traj_dict[fate] = torch.vstack(pieces)

with open('../data/F4_noisy0.0_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)

In [None]:
# Reload the saved trajectories and keep the 'MN' group for plotting.
with open('../data/F4_noisy0.0_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.noisy0.0_viz.png', bbox_inches='tight', transparent=True)
plt.show()

# NOISE 0.2

In [None]:
# Temporarily raise the SDE noise for this run. The finally clause puts the
# value back to 0.1 even when simulate() raises or is interrupted, so later
# cells never silently run with a leftover noise level.
sde_model.module.sde.noise_scalar = 0.2
try:
    trajectories, cell_ids = sde_model.simulate(
        initial_cells=initial_cells,
        n_samples_per_cell=1,
        n_steps = 100,
        t_max = 100,
        dt = 1.0,
        latent_key='X_z',
        n_chunks=10
    )
finally:
    sde_model.module.sde.noise_scalar = 0.1

In [None]:
# This run uses 7 clusters; several of them are mapped onto the same fate in the grouping cell.
n_clusters = 7

# NOTE(review): final_steps presumably limits clustering to the last 50
# simulation steps - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    trajectories, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask trajectories when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = trajectories[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the noise-0.2 run; seven clusters collapse onto three fates.
cluster_map = {0:'FP',1:'V3',2:'MN',3:'MN',4:'FP',5:'V3',6:'V3'}

# Collect each fate's pieces first, then stack them into one tensor per fate.
grouped = {fate: [] for fate in ('MN', 'V3', 'FP')}

for ci in range(n_clusters):
    fate = cluster_map[ci]
    grouped[fate].append(trajectories[labels==ci])

traj_dict = {}
for fate, pieces in grouped.items():
    traj_dict[fate] = torch.vstack(pieces)

with open('../data/F4_noisy0.2_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)

In [None]:
# Reload the saved trajectories and keep the 'MN' group for plotting.
with open('../data/F4_noisy0.2_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.noisy0.2_viz.png', bbox_inches='tight', transparent=True)
plt.show()

### Random Markov Walk

In [None]:
# Markov process over the model's latent space, used for the random-walk baseline.
# NOTE(review): the arguments are identical to the `markov` object built during
# SDE setup; whether that instance could be reused instead is not clear from here.
mp = vt.sb.MarkovProcess(
    model,
    n_neighbors=10,
    use_space='latent_space',
    use_spline=True,
    use_similarity=False
)

In [None]:
# Latent coordinates of all cells, on the model's device.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
# Integer positions (within model.adata) of the cells that belong to `initial_cells`.
init_indices = torch.arange(z.shape[0], device=model.device)[[c in initial_cells.obs_names for c in model.adata.obs_names]]

# NOTE(review): how n_jumps and n_steps relate inside random_walk is not
# visible here - confirm against vt.sb.MarkovProcess.random_walk.
walks = mp.random_walk(
    z,
    initial_states=init_indices,
    n_jumps=50,
    n_steps=100,
)

In [None]:
# The random walks are split into 6 clusters; several map onto the same fate in the grouping cell.
n_clusters = 6
# NOTE(review): final_steps presumably limits clustering to the last 50 steps
# of each walk - confirm against vt.cl.cluster_trajectories.
avg_center, labels, centers, index = vt.cl.cluster_trajectories(
    walks, 
    n_clusters=n_clusters, 
    final_steps=50,
    n_iterations=200
)

# NumPy copy of the labels, used to mask walks when plotting.
cluster_labels = labels.cpu().numpy()

In [None]:
# One figure per cluster, four panels each: component 0 against components 1-4.
panel_components = [[0,1], [0,2], [0,3], [0,4]]
last_panel = len(panel_components) - 1

for i in range(n_clusters):
    fig = plt.figure(figsize=(14,6), dpi=200)
    axes = fig.subplots(1,4)
    members = walks[cluster_labels==i]

    for panel, (ax, comps) in enumerate(zip(axes, panel_components)):
        # Every panel but the last passes show=False; the last call leaves
        # `show` at the plotting function's default.
        show_kwargs = {} if panel == last_panel else {'show': False}
        vt.pl.trajectories_2d(model, members, ax=ax,
                              line_alpha=0.5, components=comps,
                              cell_color='cell_annotation',
                              title=f"C{i}" if panel == 0 else "",
                              **show_kwargs)

In [None]:
# Cluster id -> fate label for the random walks; six clusters collapse onto three fates.
# NOTE(review): hard-coded for one particular clustering result - verify after re-running.
cluster_map = {0:'V3',1:'V3',2:'MN',3:'FP',4:'MN',5:'FP'}

traj_dict = {'MN':[],'V3':[],'FP':[]}

# Iterate over n_clusters (set in the clustering cell) rather than a literal 6,
# so this loop stays in step with the clustering like the other grouping cells.
for ci in range(n_clusters):
    traj_dict[cluster_map[ci]].append(walks[labels==ci])

# Stack each fate's pieces into a single tensor.
traj_dict = {key:torch.vstack(val) for key, val in traj_dict.items()}

with open('../data/F4_markov_trajectories.pickle', 'wb') as f:
    pickle.dump(traj_dict, f)


In [None]:
# Reload the saved random walks and keep the 'MN' group for plotting.
with open('../data/F4_markov_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every walk with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per walk.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.3.markov_viz.png', bbox_inches='tight', transparent=True)
plt.show()

# pseudotime binning

In [None]:
# NOTE(review): the return value is not captured; the results are presumably
# written onto model.adata - confirm what this call stores.
model.infer_pseudotime()

In [None]:
# Cells kept for pseudotime binning:
#   * 'Early_Neural' / 'Neural' / 'pMN' cells from timepoints D4 and D5, and
#   * every cell annotated 'MN'.
obs = model.adata.obs
is_early_progenitor = (
    obs.cell_annotation.isin(['Early_Neural','Neural','pMN'])
    & obs.timepoint.isin(['D4','D5'])
).values
is_mn_neuron = obs.cell_annotation.isin(['MN']).values

pmn = model.adata[obs.cell_annotation.isin(['Early_Neural','Neural','pMN']).values]
pmn_cells = model.adata.obs_names[is_early_progenitor]
neu_cells = model.adata.obs_names[is_mn_neuron]

# Select with a mask so the kept cells stay in model.adata's own order. The
# previous set-union produced a hash-dependent order that could differ between
# kernel sessions, which made everything derived from `sub` non-reproducible.
keep = list(model.adata.obs_names[is_early_progenitor | is_mn_neuron])
sub = model.adata[keep].copy()

In [None]:
# Build a small AnnData whose "expression" is the latent coordinates and whose
# velocity layer is the latent velocity, then run scVelo's pseudotime on it.
latent = sub.obsm['X_z']
sub2 = ann.AnnData(
    X=latent,
    layers={"X": latent, "velocity": sub.obsm['velocity_z']},
    obs=sub.obs.copy(),
)
vt.pp.neighbors(sub2, total_layer="X", n_neighbors=30, include_self=True)
scv.tl.velocity_graph(sub2, xkey="X", vkey="velocity", n_jobs=1)
scv.tl.velocity_pseudotime(sub2)
scv.pp.pca(sub2)

fig = plt.figure(figsize=(18, 6))
axes = fig.subplots(1, 3)

# Three PCA panels that differ only in the colouring key.
panel_keys = ["velocity_pseudotime", "end_points", "root_cells"]
shared_style = dict(
    basis="pca",
    cmap="gnuplot",
    fontsize=22,
    size=40,
    legend_loc="on data",
    legend_fontsize=22,
)
for ax, key in zip(axes[:-1], panel_keys[:-1]):
    scv.pl.scatter(sub2, color=key, ax=ax, show=False, **shared_style)
# The last panel leaves `show` at scVelo's default.
scv.pl.scatter(sub2, color=panel_keys[-1], ax=axes[-1], **shared_style)

# Copy the results back onto `sub`; the pseudotime is stored under the short name 't'.
for target, source in (("t", "velocity_pseudotime"),
                       ("end_points", "end_points"),
                       ("root_cells", "root_cells")):
    sub.obs[target] = sub2.obs[source]

In [None]:
# Split pseudotime into n_steps equal-population bins: the boundaries are the
# 0th, 1st, ..., 100th percentiles of sub.obs.t.
n_steps = 100
percent_grid = np.linspace(0,100,n_steps+1)
boundaries = list(np.percentile(sub.obs.t.values, percent_grid))

In [None]:
from tqdm import trange

traj = []

# Number of cells averaged per pseudotime bin.
n = 10

# Bin membership depends only on `sub` and `boundaries`, so it is computed once
# here instead of re-subsetting the AnnData inside every one of the 2000 x
# n_steps iterations.
# NOTE(review): assumes sub.obsm['X_z'] is a dense array - confirm.
t_values = np.asarray(sub.obs.t.values)
latent = np.asarray(sub.obsm['X_z'])
bin_members = []
for i in range(n_steps):
    start = boundaries[i]
    end = boundaries[i+1]
    # The first boundary is the minimum pseudotime; include it in the first
    # bin, otherwise the earliest cells fall into no bin at all.
    above_start = (t_values >= start) if i == 0 else (t_values > start)
    members = np.flatnonzero(above_start & (t_values <= end))
    if members.size == 0:
        raise ValueError(
            f"pseudotime bin {i} ({start}, {end}] contains no cells; "
            "use fewer bins (n_steps)"
        )
    bin_members.append(members)

# Each pseudo-trajectory is the sequence of per-bin means of n cells drawn
# (with replacement) from every bin.
for _ in trange(2000):
    series = []
    for members in bin_members:
        picked = members[np.random.choice(members.size, size=n)]
        series.append(latent[picked].mean(0))
    traj.append(np.vstack(series)[None,:,:])

In [None]:
# Stack the per-trajectory arrays (each has a leading axis of length 1) along
# that axis into a single tensor on the model's device.
pt_bin = torch.tensor(np.vstack(traj), device=model.device)

In [None]:
# Quick look at the binned pseudotime trajectories on components 0 and 1.
vt.pl.trajectories_2d(model, pt_bin, line_alpha=0.5, components=[0,1])

In [None]:
import pickle

# Save under the key 'MN', the same key the trajectory-plot cells read.
with open('../data/F4_pt10_bins.pickle', 'wb') as f:
    pickle.dump({'MN':pt_bin}, f)

In [None]:
from tqdm import trange

traj = []

# A single cell per pseudotime bin (no averaging).
n = 1

# Bin membership depends only on `sub` and `boundaries`, so it is computed once
# here instead of re-subsetting the AnnData inside every one of the 2000 x
# n_steps iterations.
# NOTE(review): assumes sub.obsm['X_z'] is a dense array - confirm.
t_values = np.asarray(sub.obs.t.values)
latent = np.asarray(sub.obsm['X_z'])
bin_members = []
for i in range(n_steps):
    start = boundaries[i]
    end = boundaries[i+1]
    # The first boundary is the minimum pseudotime; include it in the first
    # bin, otherwise the earliest cells fall into no bin at all.
    above_start = (t_values >= start) if i == 0 else (t_values > start)
    members = np.flatnonzero(above_start & (t_values <= end))
    if members.size == 0:
        raise ValueError(
            f"pseudotime bin {i} ({start}, {end}] contains no cells; "
            "use fewer bins (n_steps)"
        )
    bin_members.append(members)

# Each pseudo-trajectory is the sequence of per-bin means of n cells drawn
# (with replacement) from every bin.
for _ in trange(2000):
    series = []
    for members in bin_members:
        picked = members[np.random.choice(members.size, size=n)]
        series.append(latent[picked].mean(0))
    traj.append(np.vstack(series)[None,:,:])

# Stack into one tensor on the model's device.
pt_bin1 = torch.tensor(np.vstack(traj), device=model.device)

# Save under the key 'MN', the same key the trajectory-plot cells read.
with open('../data/F4_pt1_bins.pickle', 'wb') as f:
    pickle.dump({'MN':pt_bin1}, f)

In [None]:
# Quick look at the single-cell-per-bin pseudotime trajectories on components 0 and 1.
vt.pl.trajectories_2d(model, pt_bin1, line_alpha=0.5, components=[0,1])

In [None]:
# Reload the saved 10-cell-per-bin pseudotime trajectories.
with open('../data/F4_pt10_bins.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.p10_viz.png', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Reload the saved single-cell-per-bin pseudotime trajectories.
with open('../data/F4_pt1_bins.pickle', 'rb') as f:
    data = pickle.load(f)
plot_trajectories = data['MN']

# Fit a PCA on the latent coordinates of all cells.
z = torch.tensor(model.adata.obsm['X_z'], device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()
scv.pl.scatter(
    copy, basis='vae', color='cell_annotation', alpha=0.3,
    ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
    legend_loc=False, title="", fontsize=16,
)

# Project every trajectory with the same fitted PCA.
t_pca = [pca.transform(traj.detach().cpu().numpy()) for traj in plot_trajectories]

# Red cross at each start point, one randomly shaded line per trajectory.
colors = generate_blue_shades(n=len(t_pca))
for t, c in zip(t_pca, colors):
    ax1.scatter(t[0,0],t[0,1], color='red', marker='x')
    ax1.plot(t[:,0],t[:,1], color=c, alpha=0.2, linewidth=2)

plt.tight_layout()
plt.savefig('../figures/3.0.2.p1_viz.png', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Reload the noise-free trajectories and reduce the 'MN' group to one summary path.
with open('../data/F4_noisy0.0_trajectories.pickle', 'rb') as f:
    data = pickle.load(f)

fig = plt.figure(figsize=(10,8), dpi=300)
ax1 = fig.subplots()

plot_trajectories = data['MN']
# Element-wise median across trajectories (dim 0); keepdim leaves a leading
# axis of length 1 so the loops below still iterate over "trajectories".
plot_trajectories = plot_trajectories.median(0, keepdim=True).values

# Fit a PCA on the latent coordinates of all cells.
z = model.adata.obsm['X_z']
z = torch.tensor(z, device=model.device)
pca = PCA()
z_pca = pca.fit_transform(z.detach().cpu().numpy())

# Background scatter uses a copy of the AnnData carrying the PCA coordinates as basis 'vae'.
copy = model.adata.copy()
copy.obsm['X_vae'] = z_pca
copy.uns["velocity_params"] = {'embeddings':'vae'}

scv.pl.scatter(copy, basis='vae', color='cell_annotation', alpha=0.3,
              ax=ax1, size=1000, show=False, components="1,2", palette=colpal.celltype,
               legend_loc=False, title="", fontsize=16)


t_pca = []
for traj in plot_trajectories:
    t_pca.append(pca.transform(traj.detach().cpu().numpy()))


# The path is drawn in a fixed blue with a dot at each end, so no random
# palette is generated here (the earlier generate_blue_shades call was unused).
for t in t_pca:
    ax1.scatter(t[0,0],t[0,1], color='blue', marker='o', s=200)
    ax1.scatter(t[-1,0],t[-1,1], color='blue', marker='o', s=200)
    ax1.plot(t[:,0],t[:,1], color='blue', alpha=1, linewidth=6)

plt.tight_layout()

plt.savefig('../figures/3.0.5.average_viz.png', bbox_inches='tight', transparent=True)

plt.show()