In [None]:
# Reload edited Python modules (e.g. the local scmg package) before each cell executes.
%load_ext autoreload
%autoreload 2
# Render inline figures at high (retina) resolution.
%config InlineBackend.figure_format='retina'

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import umap

import torch

from scmg.model.contrastive_embedding import (CellEmbedder, 
                                        decode_cell_state_embedding)

from scmg.model.manifold_generation import ConditionalDiffusionModel, generate_transition_cells
from scmg.preprocessing.data_standardization import GeneNameMapper
# Shared gene-identifier mapper; used below to convert the decoded var names from IDs to gene names.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Publication figure defaults. pdf.fonttype=42 embeds TrueType fonts so text stays
# editable in saved PDFs. NOTE(review): 'FreeSans' must be installed, otherwise
# matplotlib warns and falls back to another family.
plt.rcParams.update({
    "figure.autolayout": False,
    "pdf.fonttype": 42,
    "font.family": "FreeSans",
})
# vector_friendly: scanpy draws scatter points as raster even in PDF/SVG exports.
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Set after set_figure_params, presumably so scanpy's style cannot turn grid lines back on.
plt.rcParams["axes.grid"] = False

In [None]:
# Load the trained contrastive cell embedder (architecture + best weights).
# SECURITY: model.pt is a pickled nn.Module; unpickling can execute arbitrary code,
# so only load checkpoints from a trusted source.
device = 'cuda:0'
model_ce_path = '../../../contrastive_embedding/trained_embedder/'

# map_location puts tensors straight onto `device`, so checkpoints saved on another
# GPU/CPU still load. weights_only=False is required for a full-module pickle on
# torch>=2.6, where the default became weights_only=True.
model_ce = torch.load(os.path.join(model_ce_path, 'model.pt'),
                      map_location=device, weights_only=False)
model_ce.load_state_dict(torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
                                    map_location=device))

model_ce.to(device)
model_ce.eval()

In [None]:
# Load the trained conditional diffusion model (architecture + best weights).
# SECURITY: model.pt is a pickled nn.Module; unpickling can execute arbitrary code,
# so only load checkpoints from a trusted source.
device = 'cuda:0'
model_d_path = '../../trained_diffusion_model'

# map_location puts tensors straight onto `device`, so checkpoints saved on another
# GPU/CPU still load. weights_only=False is required for a full-module pickle on
# torch>=2.6, where the default became weights_only=True.
model_d = torch.load(os.path.join(model_d_path, 'model.pt'),
                     map_location=device, weights_only=False)
model_d.load_state_dict(torch.load(os.path.join(model_d_path, 'best_state_dict.pth'),
                                   map_location=device))

model_d.to(device)
model_d.eval()

In [None]:
# Ordered cell types along the epiblast -> fast muscle trajectory. Each consecutive
# pair is one start -> end transition sampled from the diffusion model.
traj_cell_types = ['Epiblast', 'Primitive streak and adjacent ectoderm', 'Nascent mesoderm', 'Mixed mesoderm',
                   'Paraxial mesoderm A', 'Dermomyotome', 'Myocytes', 'fast muscle cell']

# Per-transition count passed to generate_transition_cells (presumably the number
# of cells sampled per transition).
N_CELLS_PER_TRANSITION = 500
SEED = 0

# NOTE(review): assumes generate_transition_cells draws from torch's and/or numpy's
# global RNG; seeding both makes Restart & Run All reproducible. Confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)

generated_cells = []
cond_classes = []

for start_cell_type, end_cell_type in zip(traj_cell_types[:-1], traj_cell_types[1:]):
    local_generated_cells, local_cond_classes = generate_transition_cells(
        model_d, start_cell_type, end_cell_type, N_CELLS_PER_TRANSITION)

    # Every generated cell must carry exactly one conditioning label.
    assert len(local_generated_cells) == len(local_cond_classes), (
        f'{start_cell_type} -> {end_cell_type}: '
        f'{len(local_generated_cells)} cells vs {len(local_cond_classes)} labels')

    # Flatten the per-transition results into one list of cells and one of labels.
    generated_cells.extend(local_generated_cells)
    cond_classes.extend(local_cond_classes)

generated_cells = np.array(generated_cells)
cond_classes = np.array(cond_classes)

In [None]:
# Decode the generated latent embeddings into an AnnData of gene expression.
# NOTE(review): every cell is decoded with the same 'Qiu_Organogenesis_MM_2022:all'
# tag; presumably this selects the dataset context the decoder conditions on. Confirm.
adata_generated = decode_cell_state_embedding(
    model_ce, generated_cells,
    ['Qiu_Organogenesis_MM_2022:all'] * len(generated_cells))

# Store labels as categorical explicitly. Later cells read `.cat.categories` and
# `uns['cell_type_colors']`; without this they depend on scanpy converting the
# string column as a side effect. pd.Categorical orders categories lexically.
adata_generated.obs['cell_type'] = pd.Categorical(cond_classes)

# Map gene IDs in var_names to gene names, then de-duplicate any repeated names.
adata_generated.var.index = gene_name_mapper.map_gene_names(
        adata_generated.var.index, 'human', 'human', 'id', 'name')
adata_generated.var_names_make_unique()

adata_generated

In [None]:
# kNN graph on the contrastive latent space, then a PAGA graph over the
# conditioning labels whose layout seeds the UMAP.
sc.pp.neighbors(adata_generated, n_neighbors=50, use_rep='X_ce_latent')
sc.tl.paga(adata_generated, groups='cell_type')

# plot=False only computes the PAGA layout (no figure); init_pos='paga' uses it.
sc.pl.paga(adata_generated, plot=False)
sc.tl.umap(adata_generated, min_dist=0.5, init_pos='paga')

In [None]:
# Conditioning labels present in the generated cells.
adata_generated.obs.cell_type.cat.categories

In [None]:
from sklearn.metrics.pairwise import pairwise_distances

# Reference population for the root-cell heuristic (alternative tried: 'Nascent mesoderm').
terminal_ct = 'Mixed mesoderm'

latent = adata_generated.obsm['X_ce_latent']
is_terminal = (adata_generated.obs['cell_type'] == terminal_ct).to_numpy()
ref_points = latent[is_terminal]

# Mean Euclidean latent distance from every cell to all terminal-type cells.
dists = pairwise_distances(latent, ref_points, metric='euclidean')
adata_generated.obs['terminal_ct_dist'] = dists.mean(axis=1)

In [None]:
# DPT root: the Epiblast cell with the largest mean latent distance to the terminal
# reference population ('terminal_ct_dist', computed above).
ROOT_CT = 'Epiblast'
obs = adata_generated.obs

# Work positionally: obs names are never made unique here, and Index.get_loc
# returns a slice/mask (not an int) for duplicated labels.
candidate_pos = np.flatnonzero((obs['cell_type'] == ROOT_CT).to_numpy())
if candidate_pos.size == 0:
    raise ValueError(f'No cells labelled {ROOT_CT!r}; cannot choose a DPT root.')

dist_to_terminal = obs['terminal_ct_dist'].to_numpy()
iroot = int(candidate_pos[np.argmax(dist_to_terminal[candidate_pos])])
root_cell = obs.index[iroot]
adata_generated.uns['iroot'] = iroot

# Compute the diffusion map explicitly (default parameters). Otherwise tl.dpt
# warns and falls back to it.
sc.tl.diffmap(adata_generated)
sc.tl.dpt(adata_generated)

# Relative rank of pseudotime in (0, 1]; robust to the pseudotime scale.
adata_generated.obs['dpt_rele_rank'] = adata_generated.obs['dpt_pseudotime'].rank() / adata_generated.n_obs

In [None]:
# Single-gene view on the UMAP.
# NOTE(review): INS is not an expected marker along this mesoderm -> muscle
# trajectory; possibly a leftover check or a negative control. Confirm intent.
sc.pl.umap(adata=adata_generated, color='INS')

In [None]:
import matplotlib.pyplot as plt

# UMAP of the generated cells coloured by conditioning label, saved to PDF.
fig = plt.figure(figsize=(3, 3), dpi=300)
ax = fig.add_subplot()

sc.pl.umap(
    adata_generated,
    color='cell_type',
    title='',
    frameon=False,
    legend_loc='on data',
    legend_fontsize=4,
    ax=ax,
)

fig.savefig('umap_cell_types.pdf', dpi=300)

In [None]:
import matplotlib.pyplot as plt

# UMAP coloured by relative pseudotime rank (dpt_rele_rank), saved to PDF.
fig, ax = plt.subplots(dpi=300, figsize=(3, 3))

sc.pl.umap(
    adata_generated,
    color='dpt_rele_rank',
    cmap='jet',
    title='pseudotime',
    frameon=False,
    ax=ax,
)

fig.savefig('umap_pseudotime.pdf', dpi=300)

In [None]:
# Label -> colour lookup. Pairs categories with uns['cell_type_colors'] (presumably
# written by the earlier sc.pl.umap call) in category order.
ct_color_map = dict(zip(
    adata_generated.obs['cell_type'].cat.categories,
    adata_generated.uns['cell_type_colors'],
))

adata_generated.obs['cell_type_color'] = adata_generated.obs['cell_type'].map(ct_color_map)

In [None]:
# Order cells by DPT pseudotime and genes by their expression-weighted mean
# position along relative pseudotime, for the heatmap below.
def _to_dense(mat):
    """Return `mat` as a dense ndarray (handles scipy sparse matrices too)."""
    return mat.toarray() if hasattr(mat, 'toarray') else np.asarray(mat)


X_expr = _to_dense(adata_generated.X)
rele_rank = adata_generated.obs['dpt_rele_rank'].to_numpy()

# Per-gene totals. Genes with zero total have no defined "mean time". Divide only
# where the total is non-zero instead of producing 0/0 = NaN with RuntimeWarnings.
gene_totals = X_expr.sum(axis=0)
has_signal = gene_totals != 0
X_normalized = np.divide(X_expr, gene_totals[None, :],
                         out=np.zeros(X_expr.shape, dtype=np.float64),
                         where=has_signal[None, :])

adata_generated.var['mean'] = X_expr.mean(axis=0)

# Weighted mean of relative rank per gene (matmul avoids a cells x genes temporary).
gene_mean_time = rele_rank @ X_normalized
gene_mean_time[~has_signal] = np.nan  # undefined; NaN sorts last
adata_generated.var['gene_mean_time'] = gene_mean_time

cell_order = adata_generated.obs['dpt_pseudotime'].sort_values().index.values
gene_order = adata_generated.var['gene_mean_time'].sort_values().index.values
adata_ordered = adata_generated[cell_order, gene_order].copy()

# Scale each gene by its mean over cells (expression relative to the gene mean).
# Zero-mean genes are set to 0 rather than 0/0 = NaN.
X_ordered = _to_dense(adata_ordered.X)
ordered_means = X_ordered.mean(axis=0)
adata_ordered.X = np.divide(X_ordered, ordered_means[None, :],
                            out=np.zeros(X_ordered.shape, dtype=np.float64),
                            where=(ordered_means != 0)[None, :])

In [None]:
# Flag the top 5000 highly variable genes and record each gene's maximum value.
sc.pp.highly_variable_genes(adata_generated, n_top_genes=5000)
adata_generated.var['max'] = adata_generated.X.max(axis=0)

# Only the HVG flag is used; extra filters tried earlier (mean > 0.1, max > 1) are off.
is_hv = adata_generated.var['highly_variable']
hv_genes = adata_generated.var_names[is_hv]

In [None]:
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Rows: highly variable genes (mean-time order). Columns: cells (pseudotime order).
X = adata_ordered[:, adata_ordered.var.index.isin(hv_genes)].X.T
n_genes, n_cells = X.shape

# aspect = n_cells / n_genes renders a square heatmap regardless of matrix size.
# vmax=5 clips values above 5x the gene mean.
ax.imshow(X, vmax=5, cmap='inferno_r', aspect=n_cells / n_genes)

# imshow centres pixel k at k, so the axes span [-0.5, n - 0.5]. Ticks placed at
# 0..n left the last tick ("1.0" / the gene count) outside the limits, and
# matplotlib silently dropped it. Place the ticks across the image span instead.
n_ticks = 11
ax.set_xticks(np.linspace(-0.5, n_cells - 0.5, num=n_ticks),
              [f'{x:.1f}' for x in np.linspace(0, 1, num=n_ticks)],
              size=8)
ax.set_yticks(np.linspace(-0.5, n_genes - 0.5, num=n_ticks),
              np.linspace(0, n_genes, num=n_ticks, dtype=int),
              size=8)

ax.set_xlabel('pseudotime')
ax.set_ylabel('genes')

fig.savefig('gene_expression_heatmap.pdf')

In [None]:
# Export per-gene annotations of the heatmap genes, in mean-time order.
hv_mask = adata_ordered.var.index.isin(hv_genes)
adata_ordered.var.loc[hv_mask].to_csv('var_ordered.csv')

In [None]:
# Print a slice of the heatmap genes (mean-time order) as quoted, comma-terminated
# names, n_per_line per line, for copy-pasting.
# TODO: document why positions 1500:4000 were chosen.
hv_ordered_genes = adata_ordered.var.index[adata_ordered.var.index.isin(hv_genes)]
genes_to_print = hv_ordered_genes[1500:4000]

n_per_line = 20

# Step through chunks directly. The previous `len // n_per_line + 1` loop printed a
# trailing blank line whenever the length was an exact multiple of n_per_line.
for start in range(0, len(genes_to_print), n_per_line):
    chunk = genes_to_print[start:start + n_per_line]
    print(''.join(f"'{gene}'," for gene in chunk))