In [1]:
import ot
import sgw
import scanpy as sc
import pandas as pd
import numpy as np
import networkx as nx
from geosketch import gs
import matplotlib.pyplot as plt

In [2]:
def plot_mapping(P,X_1,X_2, thresh=None, linewidth=1):
    """Draw a coupling between two 2-D point sets as segments in a 3-D plot.

    The points of ``X_1`` are scattered on the plane z=0 (blue) and the points
    of ``X_2`` on the plane z=1 (red). For every column ``i`` of ``P`` a gray
    segment joins ``X_2[i]`` to ``X_1[j]``, where ``j`` is the row holding the
    largest entry of that column.

    Parameters
    ----------
    P : 2-D array
        Coupling matrix indexed as ``P[j, i]``; rows correspond to rows of
        ``X_1`` and columns to rows of ``X_2``.
    X_1, X_2 : 2-D arrays
        Point coordinates; only the first two columns are used.
    thresh : float or None, optional
        Columns whose total mass ``P[:, i].sum()`` is below this value are not
        drawn. ``None`` (the default) disables the filter and draws every
        column.
    linewidth : float, optional
        Line width of the connecting segments.

    Returns
    -------
    None
        The new figure is left as the current pyplot figure, so the caller can
        follow up with ``plt.savefig(...)``.
    """
    plt.figure()
    ax = plt.axes(projection='3d')

    # The two point sets live on parallel planes so the coupling is visible
    # as segments running between z=0 and z=1.
    ax.scatter3D(X_1[:,0], X_1[:,1], 0, c='tab:blue')
    ax.scatter3D(X_2[:,0], X_2[:,1], 1, c='tab:red')

    for i in range(P.shape[1]):
        column = P[:,i]
        # Comparing against None raises TypeError, so only filter when a
        # threshold was actually supplied.
        if thresh is not None and column.sum() < thresh:
            continue
        j = np.argmax(column)
        ax.plot3D([X_1[j,0], X_2[i,0]],[X_1[j,1], X_2[i,1]],[0,1],c='gray', linewidth=linewidth)

    # Fully transparent panes.
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.set_zticks([0, 1])

In [3]:
# Plotly's default qualitative palette, used below to colour milestones in
# the matplotlib scatter plots.
import plotly.colors
plotly_colors = plotly.colors.qualitative.Plotly

## **Subset** — Dataset 2 (milestone M3) is contained in Dataset 1 (milestones M1–M4)

In [4]:
# Path of the input .h5ad file, relative to the notebook's working directory.
dataname = './data/dyntoy_bifurcating_3.h5ad'

In [5]:
# Load the dataset with scanpy and print its summary.
adata = sc.read_h5ad(dataname)
print(adata)

AnnData object with n_obs × n_vars = 359 × 5460
    uns: 'milestone_network'
    obsm: 'milestone_percentages'


In [6]:
# Label every cell with the milestone in which it has the largest percentage
# and store the labels as a categorical column of adata.obs.
ms_names = np.array(adata.obsm['milestone_percentages'].columns.values, str)
adata.obs['milestones'] = pd.Categorical(
    list(ms_names[np.argmax(adata.obsm['milestone_percentages'], axis=1)])
)

In [None]:
# Embedding pipeline: PCA, then a 5-nearest-neighbour graph and a t-SNE layout,
# both computed on the first n_pcs principal components.
n_pcs = 10
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_neighbors=5, n_pcs=n_pcs)
sc.tl.tsne(adata, n_pcs=n_pcs)
X_tsne = adata.obsm['X_tsne']

# One scatter layer per milestone, visited in sorted label order so that the
# colour given to each milestone does not depend on set iteration order.
milestones = np.array(sorted(set(adata.obs['milestones'])))
for i, milestone in enumerate(milestones):
    idx = np.where(adata.obs['milestones'] == milestone)[0]
    plt.scatter(
        X_tsne[idx, 0],
        X_tsne[idx, 1],
        c=plotly_colors[i],
        label=milestone,
    )

# Legend just outside the right edge of the axes; no frame, equal aspect.
plt.legend(loc=[1.01, 0])
plt.axis('off')
plt.axis('equal')
plt.tight_layout()
plt.savefig("./figures/scatter_dataset.pdf")

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [None]:
# Build a networkx graph from the pairwise 'distances' matrix stored on adata
# (presumably written by sc.pp.neighbors above — confirm), and record the
# number of cells.
A = adata.obsp['distances']
G = nx.from_numpy_array(A)
ncell = A.shape[0]

In [None]:
# All-pairs shortest-path lengths on G using the 'weight' edge attribute,
# materialised as a dict of dicts: p[source][target] -> length.
p = dict(nx.shortest_path_length(G, weight='weight'))

In [None]:
# Dense ncell x ncell matrix of shortest-path lengths: D[dst, src] = p[src][dst].
# NOTE(review): a missing p[src][dst] entry (e.g. an unreachable pair) would
# raise KeyError here — confirm the graph is connected.
D = np.array([[p[src][dst] for src in range(ncell)] for dst in range(ncell)])

In [None]:
# Keep the geodesic matrix on adata so that it is sliced together with the
# cells whenever adata is subset below.
adata.obsp['geodesics'] = D

In [None]:
# Select a subset of cells with sgw's 'mapper' downsampling, run on the first
# two principal components only; the random seed is fixed for reproducibility.
X_pca = adata.obsm['X_pca'][:,:n_pcs]
# Alternative sampler (disabled), kept for reference:
# downsample_index = sgw.downsample_data(X_pca, gs_N=50, random_state=547, method='geosketch')
downsample_index = sgw.downsample_data(X_pca[:,:2], random_state=547, method='mapper')
# Number of cells retained by the downsampling.
print(len(downsample_index))

In [None]:
# Restrict adata to the downsampled cells (all variables kept).
adata_sub = adata[downsample_index,:]

In [None]:
# Overlay the downsampled cells (dark) on the full t-SNE embedding (light)
# and save the figure.
for cells, scatter_kw in (
    (adata, {'label': 'full dataset', 'c': 'lightgrey'}),
    (adata_sub, {'label': 'downsampled dataset', 's': 30, 'c': 'dimgrey'}),
):
    tsne_xy = cells.obsm['X_tsne']
    plt.scatter(tsne_xy[:, 0], tsne_xy[:, 1], **scatter_kw)

plt.legend(loc=[1.01, 0])
plt.axis('off')
plt.axis('equal')
plt.tight_layout()
plt.savefig("./figures/scatter_mapper.pdf")

In [None]:
# Dataset 1 keeps milestones M1-M4 and dataset 2 keeps only M3. The same split
# is applied to the downsampled cells and to the full set of cells.
keep_1 = ['M1', 'M2', 'M3', 'M4']
keep_2 = ['M3']

adata_sub_1 = adata_sub[adata_sub.obs['milestones'].isin(keep_1), :]
adata_sub_2 = adata_sub[adata_sub.obs['milestones'].isin(keep_2), :]
adata_1 = adata[adata.obs['milestones'].isin(keep_1), :]
adata_2 = adata[adata.obs['milestones'].isin(keep_2), :]

In [None]:
# Show the two downsampled datasets in the shared t-SNE embedding; dataset 2
# is drawn second with smaller markers so both stay visible where they overlap.
for cells, scatter_kw in (
    (adata_sub_1, {'label': 'Dataset 1', 'c': 'tab:blue', 's': 90}),
    (adata_sub_2, {'label': 'Dataset 2', 'c': 'tab:red', 's': 40}),
):
    tsne_xy = cells.obsm['X_tsne']
    plt.scatter(tsne_xy[:, 0], tsne_xy[:, 1], **scatter_kw)

plt.legend(loc=[1.01, 0])
plt.axis('off')
plt.axis('equal')
plt.tight_layout()
plt.savefig("./figures/scatter_downsample_twodatasets.pdf")

In [None]:
# Geodesic distance matrices restricted to the downsampled cells of each dataset.
D_1, D_2 = (cells.obsp['geodesics'] for cells in (adata_sub_1, adata_sub_2))

In [None]:
# Couple the two downsampled datasets from their geodesic distance matrices.
# NOTE(review): the meaning and scale of threshold=10 are defined by sgw — confirm.
P = sgw.supervised_gromov_wasserstein(D_1, D_2, nitermax=20, threshold=10)

In [None]:
# Draw the downsampled coupling. Columns of P whose total mass is below
# 0.01 / (number of dataset-2 points) are not drawn.
plot_mapping(
    P,
    adata_sub_1.obsm['X_tsne'],
    adata_sub_2.obsm['X_tsne'],
    thresh=0.01 / len(adata_sub_2.obsm['X_tsne']),
)
plt.savefig('./figures/coupling_subset_downsample.pdf')

In [None]:
# For each dataset keep only the geodesic columns of its downsampled cells:
# DDk has one row per cell of the full dataset k and one column per
# downsampled cell. Columns are located by the first match on obs_names.
DD1 = adata_1.obsp['geodesics'][
    :, [np.where(adata_1.obs_names == name)[0][0] for name in adata_sub_1.obs_names]
]
DD2 = adata_2.obsp['geodesics'][
    :, [np.where(adata_2.obs_names == name)[0][0] for name in adata_sub_2.obs_names]
]

In [None]:
# Extend the downsampled coupling P to all cells of both datasets.
# nitermax is an iteration cap, so it is passed as an integer (10000 == 1e4)
# instead of a float; this also matches the integer used by the other
# recover_full_coupling call in this notebook.
# NOTE(review): the semantics of delta, thresh and eps are defined by sgw — confirm.
P_full = sgw.recover_full_coupling(P, DD1, DD2, delta=0.45, thresh=10, eps=0.01, nitermax=10000)

In [None]:
# Draw the recovered full coupling between all cells of the two datasets.
# It is densified first because plot_mapping indexes P column by column.
plot_mapping(
    P_full.toarray(),
    adata_1.obsm['X_tsne'],
    adata_2.obsm['X_tsne'],
    thresh=1e-8,
    linewidth=0.5,
)
plt.savefig('./figures/coupling_subset_full.pdf')

## **Partial overlap** — Dataset 1 (milestones M1, M3) and Dataset 2 (milestones M2, M3, M4) share only M3

In [None]:
# Path of the input .h5ad file for the partial-overlap experiment, relative
# to the notebook's working directory.
dataname = './data/dyntoy_bifurcating_3.h5ad'

In [None]:
# Reload the dataset from disk, replacing the `adata` object used so far,
# and print its summary.
adata = sc.read_h5ad(dataname)
print(adata)

In [None]:
# Label every cell with the milestone in which it has the largest percentage
# and store the labels as a categorical column of adata.obs.
ms_names = np.array(adata.obsm['milestone_percentages'].columns.values, str)
adata.obs['milestones'] = pd.Categorical(
    list(ms_names[np.argmax(adata.obsm['milestone_percentages'], axis=1)])
)

In [None]:
# Embedding pipeline: PCA, then a 5-nearest-neighbour graph and a t-SNE layout,
# both computed on the first n_pcs principal components.
n_pcs = 10
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_neighbors=5, n_pcs=n_pcs)
sc.tl.tsne(adata, n_pcs=n_pcs)
# t-SNE plot coloured by the milestone label of each cell.
sc.pl.tsne(adata, color=['milestones'])

In [None]:
# Build a networkx graph from the pairwise 'distances' matrix stored on adata
# (presumably written by sc.pp.neighbors above — confirm), and record the
# number of cells.
A = adata.obsp['distances']
G = nx.from_numpy_array(A)
ncell = A.shape[0]

In [None]:
# All-pairs shortest-path lengths on G using the 'weight' edge attribute,
# materialised as a dict of dicts: p[source][target] -> length.
p = dict(nx.shortest_path_length(G, weight='weight'))

In [None]:
# Dense ncell x ncell matrix of shortest-path lengths: D[dst, src] = p[src][dst].
# NOTE(review): a missing p[src][dst] entry (e.g. an unreachable pair) would
# raise KeyError here — confirm the graph is connected.
D = np.array([[p[src][dst] for src in range(ncell)] for dst in range(ncell)])

In [None]:
# Keep the geodesic matrix on adata so that it is sliced together with the
# cells whenever adata is subset below.
adata.obsp['geodesics'] = D

In [None]:
# Select a subset of cells with sgw's 'mapper' downsampling, run on the first
# two principal components only; the random seed is fixed for reproducibility.
X_pca = adata.obsm['X_pca'][:,:n_pcs]
# Alternative sampler (disabled), kept for reference:
# downsample_index = sgw.downsample_data(X_pca, gs_N=50, random_state=547, method='geosketch')
downsample_index = sgw.downsample_data(X_pca[:,:2], random_state=547, method='mapper')
# Number of cells retained by the downsampling.
print(len(downsample_index))

In [None]:
# Restrict to the downsampled cells, then split by milestone. Dataset 1 keeps
# M1 and M3, dataset 2 keeps M3, M2 and M4, so the two share only M3. The same
# split is applied to the downsampled cells and to the full set of cells.
adata_sub = adata[downsample_index]

keep_1 = ['M1', 'M3']
keep_2 = ['M3', 'M2', 'M4']

adata_sub_1 = adata_sub[adata_sub.obs['milestones'].isin(keep_1), :]
adata_sub_2 = adata_sub[adata_sub.obs['milestones'].isin(keep_2), :]
adata_1 = adata[adata.obs['milestones'].isin(keep_1), :]
adata_2 = adata[adata.obs['milestones'].isin(keep_2), :]

In [None]:
# Show the two downsampled datasets in the shared t-SNE embedding; dataset 2
# is drawn second with smaller markers so both stay visible where they overlap.
for cells, scatter_kw in (
    (adata_sub_1, {'label': 'Dataset 1', 'c': 'tab:blue', 's': 90}),
    (adata_sub_2, {'label': 'Dataset 2', 'c': 'tab:red', 's': 40}),
):
    tsne_xy = cells.obsm['X_tsne']
    plt.scatter(tsne_xy[:, 0], tsne_xy[:, 1], **scatter_kw)

plt.legend(loc=[1.01, 0])
plt.axis('off')
plt.axis('equal')
plt.tight_layout()
plt.savefig("./figures/scatter_downsample_twodatasets_partialoverlap.pdf")

In [None]:
# Per-dataset inputs for the downsampled cells: geodesic distance matrix (D_k)
# and t-SNE coordinates (X_k).
D_1, D_2 = (cells.obsp['geodesics'] for cells in (adata_sub_1, adata_sub_2))
X_1, X_2 = (cells.obsm['X_tsne'] for cells in (adata_sub_1, adata_sub_2))

In [None]:
# Couple the two downsampled datasets from their geodesic distance matrices.
# NOTE(review): the meaning and scale of threshold=5 are defined by sgw — confirm.
P = sgw.supervised_gromov_wasserstein(D_1, D_2, nitermax=20, threshold=5)

In [None]:
# For each dataset keep only the geodesic columns of its downsampled cells:
# DDk has one row per cell of the full dataset k and one column per
# downsampled cell. Columns are located by the first match on obs_names.
DD1 = adata_1.obsp['geodesics'][
    :, [np.where(adata_1.obs_names == name)[0][0] for name in adata_sub_1.obs_names]
]
DD2 = adata_2.obsp['geodesics'][
    :, [np.where(adata_2.obs_names == name)[0][0] for name in adata_sub_2.obs_names]
]

In [None]:
# Extend the downsampled coupling P to all cells of both datasets, with an
# iteration cap of 10,000,000.
# NOTE(review): the semantics of delta, thresh and eps are defined by sgw — confirm.
P_full = sgw.recover_full_coupling(P,DD1,DD2,delta=0.45, thresh=10, eps=0.01, nitermax=10000000)

In [None]:
# Draw the downsampled coupling. Columns of P whose total mass is below
# 0.01 / (number of dataset-2 points) are not drawn.
plot_mapping(
    P,
    X_1,
    X_2,
    thresh=0.01 / len(X_2),
)
plt.savefig('./figures/coupling_partialoverlap_downsample.pdf')

In [None]:
# Draw the recovered full coupling between all cells of the two datasets.
# It is densified first because plot_mapping indexes P column by column.
plot_mapping(
    P_full.toarray(),
    adata_1.obsm['X_tsne'],
    adata_2.obsm['X_tsne'],
    thresh=1e-8,
    linewidth=0.5,
)
plt.savefig('./figures/coupling_partialoverlap_full.pdf')