In [1]:
# Re-import edited modules (e.g. a development checkout of moscot) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import moscot
from anndata import AnnData
import numpy as np
from moscot.backends.ott import FGWSolver, SinkhornSolver, GWSolver
from moscot.solvers._tagged_arry import TaggedArray, Tag
import matplotlib.pyplot as plt
import jax.numpy as jnp
from typing import *
import pandas as pd
import networkx as nx
import jax.numpy as jnp
from moscot.problems.time._lineage import TemporalProblem
import os
import wot
from ott.geometry import pointcloud

In [3]:
import scanpy as sc
import matplotlib as mpl

In [4]:
import ott
import ot
from tqdm import tqdm

In [5]:
# Root directory of the WOT serum reprogramming data. Set the WOT_DATA_PATH
# environment variable to point elsewhere instead of editing this cell.
data_path = os.environ.get("WOT_DATA_PATH", "/home/icb/dominik.klein/data/wot")

VAR_GENE_DS_PATH = os.path.join(data_path, 'ExprMatrix.var.genes.h5ad')
VAR_DS_PATH = VAR_GENE_DS_PATH  # alias: same file, kept for backward compatibility
FULL_DS_PATH = os.path.join(data_path, 'ExprMatrix.h5ad')
CELL_DAYS_PATH = os.path.join(data_path, 'cell_days.txt')
SERUM_CELL_IDS_PATH = os.path.join(data_path, 'serum_cell_ids.txt')
CELL_GROWTH_PATH = os.path.join(data_path, 'growth_gs_init.txt')
TMAP_PATH = os.path.join(data_path, 'serum')
CELL_SETS_PATH = os.path.join(data_path, 'major_cell_sets.gmt')
COORDS_PATH = os.path.join(data_path, 'fle_coords.txt')

In [6]:
# Variable-gene expression matrix with day and growth-rate annotations,
# restricted to the cell ids listed in the serum id file.
adata = wot.io.read_dataset(
    VAR_GENE_DS_PATH,
    obs=[CELL_DAYS_PATH, CELL_GROWTH_PATH],
    obs_filter=SERUM_CELL_IDS_PATH,
)
adata.shape

(175472, 1479)

In [7]:
# Summary of the loaded AnnData (obs columns: day, cell_growth_rate).
adata

AnnData object with n_obs × n_vars = 175472 × 1479
    obs: 'day', 'cell_growth_rate'

In [8]:
# PCA of the expression matrix; writes adata.obsm["X_pca"] (used for the cost
# matrices below), adata.uns["pca"] and adata.varm["PCs"].
sc.tl.pca(adata)

In [9]:
# Annotate every cell with its major cell type from the WOT .gmt cell sets
# (CELL_SETS_PATH is already defined in the path-configuration cell).
cell_sets = wot.io.read_sets(CELL_SETS_PATH, as_dict=True)
# Invert {cell_type: cell ids} into {cell id: cell_type}. If a cell appears in
# several sets, the set visited last wins.
cell_to_type = {cell: cell_type for cell_type, cells in cell_sets.items() for cell in cells}
df_cell_type = pd.DataFrame(cell_to_type.items(), columns=["0", "cell_type"]).set_index("0")
# Drop any earlier annotation first. Otherwise re-running this cell would leave
# duplicate 'cell_type_x' / 'cell_type_y' columns instead of 'cell_type'.
adata.obs = pd.merge(
    adata.obs.drop(columns="cell_type", errors="ignore"),
    df_cell_type,
    how="left",
    left_index=True,
    right_index=True,
)

In [10]:
# Chronologically sorted time points. dropna() removes the NaN day (cells
# without a day value) wherever it occurs; the previous `days.pop(-1)` silently
# assumed NaN was the last unique value. Sorting guarantees that consecutive
# entries form the (earlier, later) pairs used for the cost matrices below.
days = sorted(adata.obs.day.dropna().unique())

In [11]:
# OTT point-cloud cost matrix between the PCA coordinates of the first two time
# points. It is not referenced again in the cells shown; presumably a spot check
# against the POT version.
pc_cost_matrix = pointcloud.PointCloud(
    adata[adata.obs.day == days[0]].obsm["X_pca"],
    adata[adata.obs.day == days[1]].obsm["X_pca"],
).cost_matrix

In [12]:
# POT cost matrix (ot.dist, default metric) for the same first pair of time
# points, presumably for comparison with the OTT point-cloud cost.
ot_cost_matrix = ot.dist(
    adata[adata.obs.day == days[0]].obsm["X_pca"],
    adata[adata.obs.day == days[1]].obsm["X_pca"],
)

In [13]:
# Cost matrices (ot.dist, default metric, in PCA space) between each pair of
# consecutive time points. They are stored in adata.uns["cost_matrix_<i>"]; `xs`
# maps each (day, next_day) pair to that uns key so tp.prepare can look it up.
#
# Each day's PCA coordinates are extracted once up front. Subsetting inside the
# loop would slice `adata` twice per day (once as target, once as source).
pca_by_day = {day: adata[adata.obs.day == day].obsm["X_pca"] for day in days}
xs = {}
for i, (day_src, day_tgt) in enumerate(tqdm(list(zip(days[:-1], days[1:])))):
    key = f"cost_matrix_{i}"
    adata.uns[key] = ot.dist(pca_by_day[day_src], pca_by_day[day_tgt])
    xs[(day_src, day_tgt)] = {"attr": "uns", "key": key, "tag": Tag.COST_MATRIX}

100%|██████████| 38/38 [00:05<00:00,  6.57it/s]


## Run WOT (reference; the cells below are commented out)

In [14]:
#ot_model = wot.ot.OTModel(adata,epsilon = 0.05, lambda1 = 1,lambda2 = 50) 

In [15]:
#tmap_annotated = ot_model.compute_transport_map(11,11.5)

## Run Moscot

In [16]:
# Temporal OT problem on `adata`, solved with moscot's OTT-backed Sinkhorn
# solver with JIT compilation disabled.
tp = TemporalProblem(adata, solver=SinkhornSolver(jit=False))

In [17]:
# Store 'day' as a categorical column. This runs after `days`/`xs` were built
# from the float values.
adata.obs["day"] = adata.obs["day"].astype("category")

In [18]:
# Register 'day' as the time key and the precomputed per-pair cost matrices in
# `xs`. obs["cell_growth_rate"] is used as the source marginal (a_marg);
# presumably this accounts for proliferation (confirm in moscot docs).
tp.prepare("day", x=xs, a_marg={"attr": "obs", "key":"cell_growth_rate"})

In [19]:
eps = 5.0       # entropic regularisation
lambda_1 = 1.0  # source-marginal relaxation (counterpart of WOT's lambda1)
# The source marginal is relaxed with tau_a = lambda_1 / (lambda_1 + eps).
# tau_b = 1 leaves the target side unrelaxed (presumably balanced).
tp.solve(
    eps=eps,
    tau_a=lambda_1 / (lambda_1 + eps),
    tau_b=1,
)

## Pull back IPS mass from day 18 to day 0

In [20]:
# Cell-type composition at day 18 (the `end` of the pull-back below), including
# cells without a cell-type annotation (NaN).
adata[adata.obs.day==18.0].obs.cell_type.value_counts(dropna=False)

Stromal        1239
Neural          819
NaN             534
Epithelial      503
IPS             412
Trophoblast     292
Name: cell_type, dtype: int64

In [21]:
# Inspect the day-18 subset (an AnnData view, not a copy).
adata[adata.obs.day==18.0]

View of AnnData object with n_obs × n_vars = 3799 × 1479
    obs: 'day', 'cell_growth_rate', 'cell_type'
    uns: 'pca', 'cost_matrix_0', 'cost_matrix_1', 'cost_matrix_2', 'cost_matrix_3', 'cost_matrix_4', 'cost_matrix_5', 'cost_matrix_6', 'cost_matrix_7', 'cost_matrix_8', 'cost_matrix_9', 'cost_matrix_10', 'cost_matrix_11', 'cost_matrix_12', 'cost_matrix_13', 'cost_matrix_14', 'cost_matrix_15', 'cost_matrix_16', 'cost_matrix_17', 'cost_matrix_18', 'cost_matrix_19', 'cost_matrix_20', 'cost_matrix_21', 'cost_matrix_22', 'cost_matrix_23', 'cost_matrix_24', 'cost_matrix_25', 'cost_matrix_26', 'cost_matrix_27', 'cost_matrix_28', 'cost_matrix_29', 'cost_matrix_30', 'cost_matrix_31', 'cost_matrix_32', 'cost_matrix_33', 'cost_matrix_34', 'cost_matrix_35', 'cost_matrix_36', 'cost_matrix_37'
    obsm: 'X_pca'
    varm: 'PCs'

In [22]:
#for i in range(38):
#    print(adata.uns[f"cost_matrix_{i}"].shape)

In [None]:
# Pull mass back from the IPS cells at day 18 to day 0 through the composed
# transport maps. With return_all=True, `result` presumably holds one array per
# time point, ordered from day 18 backwards: the binning cell pairs result[i]
# with days[::-1][i] and asserts the sizes match.
result = tp.pull_back_composed(start=0, end=18, key_groups="cell_type", groups=["IPS"], return_all=True)

## Compare: plot the pulled-back mass on the FLE embedding

In [None]:
# FLE coordinates (tab-separated), indexed by the first column (cell ids), which
# is what the join with bdata.obs matches on.
COORD_DF = pd.read_csv(filepath_or_buffer=COORDS_PATH, delimiter='\t', index_col=0)

In [None]:
# Rasterise the FLE coordinates onto an integer grid of nbins x nbins cells.
# Both ranges are captured before either column is overwritten.
nbins = 500
xrange = COORD_DF['x'].min(), COORD_DF['x'].max()
yrange = COORD_DF['y'].min(), COORD_DF['y'].max()
for axis, (lo, hi) in (('x', xrange), ('y', yrange)):
    scaled = np.interp(COORD_DF[axis], [lo, hi], [0, nbins - 1])
    COORD_DF[axis] = np.floor(scaled).astype(int)

In [None]:
# Work on a copy, so that attaching the FLE coordinates (left join on the cell-id
# index) and the alpha bins below leaves `adata` untouched.
bdata = adata.copy()
bdata.obs = bdata.obs.join(COORD_DF)

In [None]:
# Percentile cut-offs separating the transparency bins, in descending order.
percentile_thresholds = [90, 80]
# One alpha value per bin, highest first: [1.0, 0.5, 0.0] for two cut-offs.
alpha_bins = np.linspace(0, 1, len(percentile_thresholds) + 1)[::-1]

def bin_alpha(x, thresholds, alpha_bins):
    """Map one value to its transparency bin.

    Scalar counterpart of ``alpha_bins[np.digitize(x, thresholds)]`` for
    thresholds sorted in *descending* order, which is the vectorised form used
    when binning the pulled-back mass. The first threshold that ``x`` reaches
    selects ``alpha_bins[i]``. Values below every threshold get the last bin.

    Parameters
    ----------
    x : float
        Value to bin.
    thresholds : sequence of float
        Bin edges, sorted in descending order.
    alpha_bins : sequence
        ``len(thresholds) + 1`` alpha values, one per bin, highest bin first.

    Returns
    -------
    The element of ``alpha_bins`` for the bin containing ``x``.
    """
    for i, threshold in enumerate(thresholds):
        if x >= threshold:
            # Previously returned alpha_bins[-(i+1)], which gave the top bin the
            # lowest alpha, the opposite of the np.digitize mapping.
            return alpha_bins[i]
    # Below every threshold. Previously this fell through and returned None.
    return alpha_bins[len(thresholds)]
    
# Give every cell of each pulled-back time point a transparency bin, based on
# where its pulled-back mass falls among that day's percentiles.
# result[i] is paired with days_reverse[i], i.e. `result` is presumably ordered
# from the last day backwards; the per-day size assertion checks the pairing.
days_reverse = days[::-1]
assert len(result) <= len(days_reverse), "more pull-back results than time points"
bdata.obs["alpha_bin"] = np.nan
for day, mass in tqdm(zip(days_reverse, result), total=len(result)):
    day_mask = bdata.obs.day == day  # computed once per day and reused below
    mass = np.asarray(mass)
    assert day_mask.sum() == len(mass)
    # The tiny offset makes a cell need strictly more than a percentile value to
    # enter the higher bin, so ties (e.g. repeated zeros) land in the lower bin.
    thresholds = np.percentile(mass, percentile_thresholds) + 1e-8
    # With descending thresholds, np.digitize returns 0 for the top bin, so the
    # largest masses receive alpha_bins[0].
    bdata.obs.loc[day_mask, "alpha_bin"] = alpha_bins[np.digitize(mass, thresholds)]

In [None]:
# Number of cells per transparency bin. Cells left as NaN are not counted,
# because value_counts drops NaN by default.
bdata.obs.alpha_bin.value_counts()

In [None]:
# Plot the pulled-back mass of every binned time point on the FLE embedding.
# Colour encodes the day *value* rather than the list position. The sampled days
# are unevenly spaced, so index-based colours did not match the day-labelled
# colour bar. Alpha encodes the percentile bin assigned above.
plotted_days = days_reverse[:len(result)]  # exactly the days given an alpha_bin
cm = plt.get_cmap('jet')
cNorm = mpl.colors.Normalize(vmin=min(plotted_days), vmax=max(plotted_days))
scalarMap = mpl.cm.ScalarMappable(norm=cNorm, cmap=cm)

fig, ax_embedding = plt.subplots(figsize=(13, 10))
# Raw f-string: '\e' is an invalid escape sequence in a plain string literal.
# The title reports the eps actually passed to tp.solve instead of a hardcoded 5.
ax_embedding.set_title(rf'Cell type: iPSC, medium: serum, $\epsilon$={eps}', fontsize=24)
# Faint grey background containing every cell of the embedding.
ax_embedding.plot(COORD_DF['x'], COORD_DF['y'], marker='.', color='grey', ls='',
                  markersize=0.3, alpha=0.07)
for day in plotted_days:
    day_mask = bdata.obs["day"] == day
    day_rgba = np.array(scalarMap.to_rgba(day))
    for b in alpha_bins:
        if b == 0:
            continue  # fully transparent points would be invisible anyway
        colorVal = day_rgba.copy()
        colorVal[3] = b
        sel = day_mask & (bdata.obs["alpha_bin"] == b)
        ax_embedding.plot(bdata.obs.loc[sel, "x"], bdata.obs.loc[sel, "y"],
                          marker='.', color=colorVal, ls='', markersize=1)
ax_embedding.set_xlabel('FLE1', fontsize=24)
ax_embedding.set_ylabel('FLE2', fontsize=24)
# The colour bar uses the same norm as the points, so its labels are true days.
cax, _ = mpl.colorbar.make_axes(ax_embedding, shrink=1)
cbar = mpl.colorbar.ColorbarBase(cax, cmap=cm, norm=cNorm)
plt.show()