In [1]:
import sys
sys.path.append('../src/')

import os
import ot
import config
import logging
import numpy as np
import scanpy as sc
import pandas as pd
import benchmarking
import plotly.express as px
import matplotlib.pyplot as plt

from tqdm import tqdm
from refcm import RefCM
from embeddings import HVGEmbedder
from sklearn.metrics.pairwise import pairwise_distances


config.start_logging(logging.DEBUG)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0


# Fig. 1a

We use the PBMC datasets, modified to reduce number of clusters and introduce merging/splitting, to illustrate the RefCM algorithm.

In [2]:
# load the datasets
q = sc.read_h5ad('../data/pbmc_Drop-Seq.h5ad') 
ref = sc.read_h5ad('../data/pbmc_CEL-Seq.h5ad') 

[h5py._conv      ] [DEBUG   ] : Creating converter from 3 to 5


In [3]:
# labels available in each
q.obs['labels'].unique().tolist(), ref.obs['labels'].unique().tolist()

(['CD4+ T cell',
  'Cytotoxic T cell',
  'Natural killer cell',
  'CD14+ monocyte',
  'CD16+ monocyte',
  'B cell',
  'Dendritic cell',
  'Megakaryocyte',
  'Plasmacytoid dendritic cell'],
 ['CD4+ T cell',
  'Cytotoxic T cell',
  'Natural killer cell',
  'CD16+ monocyte',
  'CD14+ monocyte',
  'Megakaryocyte',
  'B cell'])

In [4]:
# we will merge reference T cells as example of a coarser annotation level
ref.obs['labels'] = ref.obs['labels'].astype('str')
t_mask = ref.obs.labels.isin(['CD4+ T cell', 'Cytotoxic T cell'])
ref.obs.loc[t_mask, 'labels'] = 'T cell'

# keep only select query & reference cell types/clusters for simplicity
q_mask = q.obs.labels.isin(['Megakaryocyte', 'CD4+ T cell', 'Cytotoxic T cell', 'Natural killer cell'])
ref_mask = ref.obs.labels.isin(['B cell', 'T cell', 'Natural killer cell'])
 
q = q[q_mask]
ref = ref[ref_mask]

In [7]:
rcm = RefCM(cache_load=False, cache_save=False, max_merges=2)
m = rcm.annotate(q, 'pbmc_Drop-Seq', ref, 'pbmc_CEL-Seq', 'labels', 'labels')

# slightly lower type equality strictness for display, since "CD4+ T cell" is a subset of "T cells"
m.set_type_equality_strictness(0.7)

# evaluate, and display costs with results
m.display_matching_costs('labels', show_values=True, display_mapped_pairs=False)

In [None]:
# review final selected edges
m.display_matching_costs('labels')

# Fig. 1b

Here instead we take a deeper dive into the RefCM algorithm; specifically the OT step, and plot all the intermediate results.

## UMAPS & embedding

In [2]:
q = sc.read_h5ad('../data/pbmc_CEL-seq.h5ad') 
ref = sc.read_h5ad('../data/pbmc_Smart-Seq2.h5ad') 

labels = sorted(list(set(q.obs.labels.unique()) | set(ref.obs.labels.unique())))
palette = {k: sc.plotting.palettes.default_102[i] for i, k in enumerate(labels)}

sc.pp.normalize_total(q, 1e4)
sc.pp.normalize_total(ref, 1e4)

sc.pp.log1p(q)
sc.pp.log1p(ref)

sc.tl.pca(q)
sc.tl.pca(ref)

sc.pp.neighbors(q)
sc.pp.neighbors(ref)

sc.tl.umap(q)
sc.tl.umap(ref)

In [3]:
def umap(
    ds: sc.AnnData, key:str, palette: dict, title:str='', width: int = 750, height: int = 750, msize: int = 3
):
    # compute information to plot
    df = pd.DataFrame(
        {
            "x": ds.obsm["X_umap"][:, 0],
            "y": ds.obsm["X_umap"][:, 1],
            "label": ds.obs[key],
        }
    ).sort_values('label')

    # plot
    fig = px.scatter(
        df,
        x="x",
        y="y",
        color='label',
        color_discrete_map=palette,
        title=title
    )

    fig.update_layout(width=width, height=height, plot_bgcolor="white")
    fig.update_traces(marker=dict(size=msize))
    fig.update_legends(
        itemsizing="constant",
        title=None,
        orientation="v",
        xanchor="left",
        yanchor="top",
    )
    fig.update_xaxes(
        showgrid=False,
        showline=True,
        mirror=True,
        linecolor="black",
        showticklabels=False,
        title=None,
    )
    fig.update_yaxes(
        scaleanchor="x",
        scaleratio=1,
        showline=True,
        mirror=True,
        linecolor="black",
        showgrid=False,
        showticklabels=False,
        title=None,
    )
    
    return fig

In [4]:
fig1 = umap(q, 'labels', palette, title='CEL-Seq', width=600, height=500, msize=2)
fig2 = umap(ref, 'labels', palette, title='Smart-Seq2', width=600, height=500, msize=2)

In [20]:
fig1.update_layout(margin=dict(l=10, r=10, t=50, b=10), width=650, height=500, title='Query<br><sup>PBMC CEL-Seq</sup>')
fig1.update_traces(marker=dict(size=1.5))
fig1.write_image('fig1/CEL-Seq.png', scale=3)
fig1.show()

In [21]:
fig2.update_layout(margin=dict(l=10, r=10, t=50, b=10), width=650, height=500, title='Reference<br><sup>PBMC Smart-Seq2</sup>')
fig2.update_traces(marker=dict(size=1.5))
fig2.write_image('fig1/Smart-Seq2.png', scale=3)
fig2.show()

In [22]:
q = sc.read_h5ad('../data/pbmc_CEL-seq.h5ad') 
ref = sc.read_h5ad('../data/pbmc_Smart-Seq2.h5ad')

hvg = HVGEmbedder(max_cluster_size=None)
hvg.fit(q, ref)

q_embed = hvg.embed(q)
ref_embed = hvg.embed(ref)

# create a new anndata object for joint umap
X = np.concatenate((q_embed, ref_embed), axis=0)
obs = pd.concat((q.obs.labels, ref.obs.labels), axis=0)
obs = obs.to_frame(name='labels')

ad = sc.AnnData(X, obs)

# plot new umaps
sc.pp.normalize_total(ad, 1e4)
sc.pp.log1p(ad)
sc.tl.pca(ad)
sc.pp.neighbors(ad)
sc.tl.umap(ad)

# separate back out after joint embedding
q_embed = ad[:q_embed.shape[0]].copy()
ref_embed = ad[q_embed.shape[0]:].copy()

In [24]:
fig1 = umap(q_embed, 'labels', palette, title='Query <br><sup>joint embedding </sup>', width=600, height=500, msize=2)
fig2 = umap(ref_embed, 'labels', palette, title='Reference <br><sup>joint embedding </sup>', width=600, height=500, msize=2)


In [25]:
fig1.update_layout(margin=dict(l=10, r=10, t=50, b=10), width=650, height=500)
fig1.update_traces(marker=dict(size=1.5))
fig1.write_image('fig1/q_embed.png', scale=3)
fig1.show()

In [26]:
fig2.update_layout(margin=dict(l=10, r=10, t=50, b=10), width=650, height=500)
fig2.update_traces(marker=dict(size=1.5))
fig2.write_image('fig1/ref_embed.png', scale=3)
fig2.show()

## OT steps

### cluster pairings

In [27]:
q_labels = sorted(q_embed.obs.labels.unique())
ref_labels = sorted(ref_embed.obs.labels.unique())

In [170]:
os.makedirs("fig1", exist_ok=True)
for i, ql in enumerate(q_labels):
    for j, rl in enumerate(ref_labels):
        fig = umap(q_embed[q_embed.obs.labels==ql], 'labels', palette, width=250, height=250, msize=2)
        fig.update_layout(showlegend=False)
        fig.update_layout(
            margin=dict(l=10, r=10, t=10, b=10),
        )

        fig2 = umap(ref_embed[ref_embed.obs.labels==rl], 'labels', palette, width=250, height=250, msize=2)
        fig.add_trace(fig2.data[0])
        fig.write_image(f"fig1/embed_{i}_{j}.png")


For each, we display the euclidean distance between them.

In [28]:
def imshow(
    sbs, M, width: int = 250, height: int = 250, max_elts: int = 100
):
    zmin, zmax = M.min(), M.max()
    
    # subset for simplicity
    sbs = sbs[:max_elts, :max_elts]
    fig = px.imshow(sbs, zmin=zmin, zmax=zmax, width=width, height=height)

    fig.update_layout(plot_bgcolor="white", showlegend=False, margin=dict(l=10, r=10, t=10, b=10))
    fig.update_traces(showscale=False, coloraxis=None)
    
    fig.update_xaxes(
        showgrid=False,
        showticklabels=False,
        title=None,
    )
    fig.update_yaxes(
        scaleanchor="x",
        scaleratio=1,
        showgrid=False,
        showticklabels=False,
        title=None,
    )
    
    return fig

In [43]:
M = pairwise_distances(q_embed.X, ref_embed.X)

In [44]:
for i, ql in enumerate(q_labels):
    qmask = q_embed.obs.labels==ql
    for j, rl in enumerate(ref_labels):
        rmask = ref_embed.obs.labels==rl
        
        dists = M[qmask][:, rmask]
        
        fig = imshow(dists, M, max_elts=100)
        fig.write_image(f"fig1/dists_{i}_{j}.png")
        print(f'{i}, {j}')

0, 0
0, 1
0, 2
0, 3
0, 4
0, 5
1, 0
1, 1
1, 2
1, 3
1, 4
1, 5
2, 0
2, 1
2, 2
2, 3
2, 4
2, 5
3, 0
3, 1
3, 2
3, 3
3, 4
3, 5
4, 0
4, 1
4, 2
4, 3
4, 4
4, 5
5, 0
5, 1
5, 2
5, 3
5, 4
5, 5
6, 0
6, 1
6, 2
6, 3
6, 4
6, 5


Then, we display the optimal transport plan between each

In [31]:
from tqdm import tqdm

In [33]:
a.shape

(2362,)

In [35]:
b.shape

(2353,)

In [37]:
dists.shape

(850, 904)

In [45]:
unif = lambda s: np.ones(s) / s

tqdm_bar = "|{bar:16}| [{percentage:>6.2f}% ] : {elapsed}"
with tqdm(total=len(q_labels) * len(ref_labels), bar_format=tqdm_bar) as pbar:
    for i, ql in enumerate(q_labels):
        qmask = q_embed.obs.labels == ql
        x_qc = q_embed[qmask].X

        for j, rl in enumerate(ref_labels):
            rmask = ref_embed.obs.labels == rl
            x_rc = ref_embed[rmask].X

            a, b = unif(len(x_qc)), unif(len(x_rc))
            M = pairwise_distances(x_qc, x_rc)
            
            plan = ot.emd(a, b, M, numItermax=1e7)
            pbar.update(1)
            
            fig = imshow(plan, plan, max_elts=100)
            fig.write_image(f"fig1/plans_{i}_{j}.png")


|████████████████| [100.00% ] : 01:37


To retrieve and plot the emd costs, we can just go ahead and run RefCM directly.

In [46]:
q = sc.read_h5ad('../data/pbmc_CEL-seq.h5ad') 
ref = sc.read_h5ad('../data/pbmc_Smart-Seq2.h5ad') 

rcm = RefCM(cache_load=False, cache_save=False, pdist='euclidean')
m = rcm.annotate(q, 'pbmc_CEL-seq', ref, 'Smart-Seq2', 'labels', 'labels')

# evaluate, and display costs with results
m.eval('labels')
fig = m.display_matching_costs('labels', display_mapped_pairs=False, show_values=True, show_all_labels=True)
fig.show()

NOTE: raw counts expected in anndata .X attributes.
[refcm           ] [INFO    ] : NOTE: raw counts expected in anndata .X attributes.
|████████████████| [100.00% ] : 01:29
pbmc_CEL-seq         to Smart-Seq2          
[matchings       ] [INFO    ] : pbmc_CEL-seq         to Smart-Seq2          
6     common cell types
[matchings       ] [INFO    ] : 6     common cell types
6 /6  correct   links
[matchings       ] [INFO    ] : 6 /6  correct   links
0     incorrect links
[matchings       ] [INFO    ] : 0     incorrect links


In [52]:
fig.update_layout(width=1000, height=1000)

In [53]:
fig.write_image('fig1/res.png', scale=3)
