Goal: perform wavelet-based multi-view clustering (WMC) on a new dataset.

In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc

# Figure-saving defaults applied to every savefig call in this notebook
plt.rcParams['savefig.dpi'] = 600
plt.rcParams['savefig.bbox'] = 'tight'
# Outline colours keyed by integer index; default `edgecolor_dict` argument of
# wmc_results and plot_dendrograms, which use them for the boxes drawn around rows
edgecolor_dict = {0: '#243078', 1: '#BEA2C4', 2: '#ACC9E5', 3: '#6AB5C0'}

Citation: Chen, Sijie (2024). Single-cell RNA-seq dataset of innate lymphoid cells. figshare. Dataset. https://doi.org/10.6084/m9.figshare.27190692.v1

In [7]:
def log_norm(adata):
  """Normalize `adata` in place: scale totals to 1e4, then apply log1p.

  Totals are computed with `exclude_highly_expressed=True`.
  """
  sc.pp.normalize_total(adata, target_sum=1e4, exclude_highly_expressed=True)
  sc.pp.log1p(adata)

adata = sc.read_h5ad('220516-ABM.velo.h5ad')

# Quality control: keep cells with 200 to 3,000 detected genes ...
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_cells(adata, max_genes=3000)
# ... and genes detected in at least 20 cells
sc.pp.filter_genes(adata, min_cells=20)

# Normalize the main matrix, then every layer through a throwaway AnnData wrapper
log_norm(adata)
for layer_name in adata.layers:
  layer_adata = sc.AnnData(adata.layers[layer_name])
  log_norm(layer_adata)
  adata.layers[layer_name] = layer_adata.X

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


mDWT
*   input: sparse matrix X whose rows represent observations; integer M, one of 2, 3, or 4
*   output: M-band DWT decomposition of X

In [174]:
import pywt
import scipy.sparse as sp

def two_band(X):
  """Split each row of X into low- and high-frequency parts with a 1-level db4 DWT.

  Every row is decomposed once with the 'db4' wavelet and reconstructed twice:
  keeping only the approximation coefficients ("low") and keeping only the
  detail coefficients ("high1").

  Args:
    X: dense 2-D array; each row is one signal.

  Returns:
    dict with keys "low" and "high1", each a scipy.sparse CSR matrix with the
    same shape as X.
  """
  X = X.T  # signals now run along axis 0
  n, m = X.shape
  nn = n + n % 2  # zero-pad odd lengths by one sample; the padding is cut off again below

  XX = np.zeros((nn, m))
  XX[:n, :] = X

  # Transform all columns in one call along axis 0 instead of looping over
  # them in Python; the per-column arithmetic is unchanged.
  approx, detail = pywt.wavedec(XX, 'db4', level=1, axis=0)
  VV = pywt.waverec([approx, np.zeros_like(detail)], 'db4', axis=0)
  WW = pywt.waverec([np.zeros_like(approx), detail], 'db4', axis=0)

  VV = VV[:n, :]
  WW = WW[:n, :]
  result = {"low": sp.csr_matrix(VV.T), "high1": sp.csr_matrix(WW.T)}
  return result

def three2WTM(n):
  """Build the n x n 3-band wavelet transform matrix.

  With k = n // 3, rows [j * k, (j + 1) * k) hold filter j of the bank
  (h0, h1, h2). Row i of a band carries the six filter taps starting at
  column 3 * i; taps that run past the last column wrap around to column 0.
  When wrapped taps land on a column that is already occupied (n == 3) the
  values are added rather than overwritten.

  Args:
    n: matrix size; must be divisible by 3.

  Returns:
    scipy.sparse CSR matrix of shape (n, n).

  Raises:
    ValueError: if n is not divisible by 3.
  """
  if (n % 3 != 0):
    raise ValueError(str(n) + ' is not divisible by 3')
  # Filter banks
  h0 = np.array([0.33838609728386, 0.53083618701374, 0.72328627674361,
                 0.23896417190576, 0.04651408217589, -0.14593600755399])
  h1 = np.array([-0.11737701613483, 0.54433105395181, -0.01870574735313,
                 -0.69911956479289, -0.13608276348796, 0.42695403781698])
  h2 = np.array([0.40363686892892, -0.62853936105471, 0.46060475252131,
                 -0.40363686892892, -0.07856742013185, 0.24650202866523])
  # Matrix of filter banks created for convenience
  h = np.array([h0, h1, h2])

  k = n // 3
  if k == 0:
    return sp.csr_matrix((n, n))

  n_bands, n_taps = h.shape
  full = (n_bands, k, n_taps)
  band = np.arange(n_bands).reshape(-1, 1, 1)
  shift = np.arange(k).reshape(1, -1, 1)
  tap = np.arange(n_taps).reshape(1, 1, -1)

  rows = np.broadcast_to(k * band + shift, full)
  # The modulo implements the wrap-around of the last row of each band
  cols = np.broadcast_to((3 * shift + tap) % n, full)
  vals = np.broadcast_to(h[:, None, :], full)

  # COO construction sums entries that share a (row, col) position on
  # conversion, which is what the periodic wrap needs when n == 3.
  W = sp.coo_matrix((vals.ravel(), (rows.ravel(), cols.ravel())), shape=(n, n))
  return W.tocsr()

def four2WTM(n):
  """Build the n x n 4-band wavelet transform matrix.

  With k = n // 4, rows [j * k, (j + 1) * k) hold filter j of the bank
  (h0, h1, h2, h3). Row i of a band carries the eight filter taps starting at
  column 4 * i; taps that run past the last column wrap around to column 0.
  When wrapped taps land on a column that is already occupied (n == 4) the
  values are added rather than overwritten.

  Args:
    n: matrix size; must be divisible by 4.

  Returns:
    scipy.sparse CSR matrix of shape (n, n).

  Raises:
    ValueError: if n is not divisible by 4.
  """
  if (n % 4 != 0):
    raise ValueError(str(n) + ' is not divisible by 4')
  #Filter banks
  h0 = np.array([-0.067371764, 0.094195111, 0.40580489, 0.567371764,
                 0.567371764, 0.40580489, 0.094195111, -0.067371764])
  h1 = np.array([-0.094195111, 0.067371764, 0.567371764, 0.40580489,
                 -0.40580489, -0.567371764, -0.067371764, 0.094195111])
  h2 = np.array([-0.094195111, -0.067371764, 0.567371764, -0.40580489,
                 -0.40580489, 0.567371764, -0.067371764, -0.094195111])
  h3 = np.array([-0.067371764, -0.094195111, 0.40580489, -0.567371764,
                 0.567371764, -0.40580489, 0.094195111, 0.067371764])
  #Matrix of filter banks created for convenience
  h = np.array([h0, h1, h2, h3])

  k = n // 4
  if k == 0:
    return sp.csr_matrix((n, n))

  n_bands, n_taps = h.shape
  full = (n_bands, k, n_taps)
  band = np.arange(n_bands).reshape(-1, 1, 1)
  shift = np.arange(k).reshape(1, -1, 1)
  tap = np.arange(n_taps).reshape(1, 1, -1)

  rows = np.broadcast_to(k * band + shift, full)
  # The modulo implements the wrap-around of the last row of each band
  cols = np.broadcast_to((4 * shift + tap) % n, full)
  vals = np.broadcast_to(h[:, None, :], full)

  # COO construction sums entries that share a (row, col) position on
  # conversion, which is what the periodic wrap needs when n == 4.
  W = sp.coo_matrix((vals.ravel(), (rows.ravel(), cols.ravel())), shape=(n, n))
  return W.tocsr()

def m_band(S, M):
  """Split the rows of S into M sub-band components (M = 3 or 4).

  The columns of S are zero-padded up to a multiple of M, the M-band
  transform matrix W of that size is built, and for each band i the rows
  W_i of that band are applied as S @ (W_i.T @ W_i). The padding columns
  are dropped from every result.

  Args:
    S: 2-D matrix, sparse or dense; dense input is converted to CSR so the
      result type does not depend on whether padding was needed.
    M: number of bands, 3 or 4 (any other value raises KeyError).

  Returns:
    dict mapping 'low', 'high1', ..., f'high{M - 1}' to sparse matrices with
    the same shape as S.
  """
  if not sp.issparse(S):
    S = sp.csr_matrix(S)
  m, n = S.shape
  if (n % M != 0):
    S = sp.hstack([S, sp.csr_matrix((m, M - (n % M)))], format='csr')
  W = {3: three2WTM, 4: four2WTM}[M](S.shape[1])

  # Directly obtain A1, D1, D2, D3
  k = S.shape[1] // M
  dwt_dict = {}
  for i in range(M):
    W_i = W[i * k:(i + 1) * k]  # the k rows belonging to band i
    dwt = S @ W_i.T.dot(W_i)
    dwt_dict['low' if i == 0 else f'high{i}'] = dwt[:, :n]
  return dwt_dict

def mDWT(X, M):
  """Dispatch to the 2-band or M-band decomposition of X.

  Args:
    X: matrix whose rows are observations; sparse input is converted to a
      dense array for the 2-band path only.
    M: number of bands; 2, 3 or 4.

  Returns:
    The dict produced by two_band (M == 2) or m_band (M == 3 or 4).

  Raises:
    ValueError: if M is not 2, 3 or 4.
  """
  if M not in (2, 3, 4):
    raise ValueError('Invalid band: M must be 2, 3, or 4')
  if M == 2:
    dense = X.toarray() if sp.issparse(X) else X
    return two_band(dense)
  return m_band(X, M)

def test_mDWT(M_values):
  """Plot a random 1 x 100 test signal and its mDWT sub-bands.

  One column of panels per entry of M_values: the top panel shows the input
  signal, the panels below show the sub-bands returned by mDWT in dict order.
  Panels a column does not need are hidden.

  Args:
    M_values: non-empty sequence of band counts accepted by mDWT.
  """
  # squeeze=False keeps `axs` 2-D even when M_values has a single entry,
  # so the axs[i, idx] indexing below always works.
  fig, axs = plt.subplots(max(M_values) + 1, len(M_values), figsize=(10, 18),
                          squeeze=False)
  np.random.seed(0)
  X = np.random.rand(1, 100)
  for idx, M in enumerate(M_values):
    result = mDWT(X, M)
    for i, wt in zip(range(M + 1), [X, *result.values()]):
      ax = axs[i, idx]
      ax.plot((wt.toarray() if sp.issparse(wt) else wt).flatten())
      ax.set_title(i)
    # Columns with fewer bands than max(M_values) leave panels unused
    for unused_ax in axs[M + 1:, idx]:
      unused_ax.axis('off')
  plt.tight_layout()

Wavelet time analysis

In [175]:
import timeit

def mDWT_times(adata=adata, M_values=[2, 3, 4], step=100):
    """Time mDWT on growing row subsets of adata.X and save the results.

    For each band count in M_values and each subset size 1, 1 + step, ...,
    mDWT is timed with timeit.repeat (one call per repetition) and the
    fastest repetition is kept. The table, with the subset sizes in a leading
    'Cells' column and one '<M>-Band' column per band count, is written to
    'WMC times.csv'. Nothing is returned.
    """
    n_range = list(range(1, adata.X.shape[0] + 1, step))

    timings = {}
    for M in M_values:
        fastest = []
        for n in n_range:
            runs = timeit.repeat(
                stmt=lambda: mDWT(adata.X[:n, :], M),
                number=1
            )
            fastest.append(min(runs))
        timings[f'{M}-Band'] = fastest

    mDWT_df = pd.DataFrame(timings)
    mDWT_df.insert(0, 'Cells', n_range)
    mDWT_df.to_csv('WMC times.csv', index=False)

def plot_mDWT_times(mDWT_df):
    """Scatter the timings in mDWT_df with a linear fit per column.

    The 'Cells' column is the x axis; every column after the first is drawn
    as a labelled scatter series plus a degree-1 polynomial fit. The figure
    is saved to 'WMC times.pdf' and closed.
    """
    fig, ax = plt.subplots()
    for band_label in mDWT_df.columns[1:]:
        cells = mDWT_df['Cells']
        seconds = mDWT_df[band_label]
        ax.scatter(cells, seconds, label=band_label)
        slope, intercept = np.polyfit(cells, seconds, deg=1)
        # Labels starting with an underscore are left out of the legend
        ax.plot(cells, slope * cells + intercept, label=f'_{band_label}')
    ax.set_xlabel('Number of Innate Lymphoid Cells')
    ax.set_ylabel('Time (seconds)')
    ax.legend(title='WMC')
    fig.savefig('WMC times.pdf')
    plt.close(fig)

In [214]:
# Time mDWT on growing subsets of cells; results are written to 'WMC times.csv'
mDWT_times()

In [219]:
# Reload the saved timings and plot them (writes 'WMC times.pdf')
plot_mDWT_times(pd.read_csv('WMC times.csv'))

Perform wavelet decomposition

In [176]:
# One decomposition of adata.X per band count: maps M -> dict of sub-band matrices
dwt_dicts = {M: mDWT(adata.X, M) for M in [2, 3, 4]}

Perform dimension reduction

In [177]:
def dimreduce(adata):
  """Scale, embed and cluster `adata`, storing the results on the object.

  Runs, in order: scaling, PCA, a neighbor graph built on the 'X_pca'
  representation, Leiden clustering, and a 3-component UMAP. Nothing is
  returned; callers read `obs['leiden']` and `obsm['X_umap']` from the
  object afterwards.
  """
  sc.pp.scale(adata)
  sc.tl.pca(adata)
  sc.pp.neighbors(adata, use_rep='X_pca')
  sc.tl.leiden(adata)
  sc.tl.umap(adata, n_components=3)

# Baseline: embed and cluster the matrix without any wavelet decomposition
dimreduce(adata)
# Embed and cluster every sub-band separately, then copy the matrix, the
# Leiden labels and the UMAP coordinates onto the main object
for band_count, subbands in dwt_dicts.items():
  for subband_name, coeffs in subbands.items():
    tag = f'{band_count}band_{subband_name}'
    adata.layers[f'wmc_{tag}'] = coeffs
    subband_adata = sc.AnnData(X=coeffs)
    dimreduce(subband_adata)
    adata.obs[f'leiden_{tag}'] = subband_adata.obs['leiden'].to_numpy()
    adata.obsm[f'X_umap_{tag}'] = subband_adata.obsm['X_umap']
# Save DWT decomposition
adata.write_h5ad('220516-ABM.velo.wmc.h5ad')

Plot WMC results compared to UMAP without WMC

In [178]:
def wmc_results(adata, color='ann0608', cell='type', edgecolor_dict = edgecolor_dict):
  """Draw a grid of 3-D UMAP panels: one row without WMC, one row per band count.

  Every `adata.obsm` key containing 'X_umap_' is plotted. The key is split on
  '_': the first character of the third field is read as the band count M and
  the fourth field as the sub-band name ('low', 'high1', ...). The figure is
  saved to f'wmc_{cell}.pdf' and closed; nothing is returned.

  Args:
    adata: AnnData holding the plain UMAP plus the per-sub-band embeddings.
    color: observation column passed to scanpy as the colouring variable.
    cell: word used in the legend title ('Cell <cell>') and in the file name.
    edgecolor_dict: maps row index to the colour of the box drawn around that row.
  """
  # Collect coordinates corresponding to WMC
  keys = [key for key in adata.obsm.keys() if 'X_umap_' in key]
  unique_M = sorted(set(int(key.split('_')[2][0]) for key in keys))
  fig, axes = plt.subplots(nrows=len(unique_M) + 1, ncols=unique_M[-1] + 1,
                           figsize=(10, 10), subplot_kw={'projection': '3d'})
  # Plot UMAP without WMC
  ax = axes[0, 1]
  sc.pl.umap(adata, color=color, ax=ax, show=False, title='No WMC',
             projection='3d')
  # Keep this panel's legend entries for the shared figure legend, then drop the panel legend
  legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
  legend_handles = ax.get_legend().legend_handles
  ax.get_legend().remove()
  # Turn off blank subplots
  # Column 0 of each row only carries the row caption and the panel letter (A, B, ...)
  for row in range(axes.shape[0]):
    ax = axes[row, 0]
    ax.text(x=0.5, y=0.5, z=0.5, size='xx-large', ha='center', va='center',
            s='Without\nWMC' if row == 0 else f'{row + 1}-Band')
    ax.text(x=-0.65, y=0.7, z=1.4, s=chr(row + 65), size='xx-large', va='top')
    # Row `row` uses columns 1 .. row + 1; hide column 0 and everything to the right
    for col in [0, *range(row + 2, axes.shape[1])]:
      axes[row, col].axis('off')
  # Plot WMC figures
  for i, key in enumerate(keys):
    M = int(key.split('_')[2][0])
    freq = key.split("_")[3].capitalize()
    # 'High1' -> 'High 1'; 'Low' is left as is
    if len(freq) > 3:
      freq = freq[:-1] + ' ' + freq[-1]
    # Row M - 1; 'Low' goes to column 1, 'High k' to column 1 + k
    ax = axes[M - 1, 1 + (0 if freq == 'Low' else int(freq[-1]))]
    sc.pl.embedding(adata, basis=key, color=color, ax=ax, show=False,
                    legend_loc='none', title=freq, projection='3d')
  # Set axes
  for ax in axes.flatten():
    ax.set_title(ax.get_title(), pad=-40)
    ax.set_xlabel('UMAP 1', labelpad=-15)
    ax.set_ylabel('UMAP 2', labelpad=-15)
    ax.set_zlabel('UMAP 3', labelpad=-15)

  fig.legend(title=f'Cell {cell}', handles=legend_handles, labels=legend_labels,
             fontsize='x-large', title_fontsize='x-large')
  fig.tight_layout()

  # Add boxes around rows
  # Positions are read after tight_layout so the boxes follow the final layout
  for row in range(axes.shape[0]):
    ax = axes[row, 0]
    box = ax.get_position()
    fig.add_artist(plt.Rectangle(
      (box.x0 - 0.5 * fig.subplotpars.left, box.y0 - 0.6 * fig.subplotpars.bottom),
      (box.width + 1.2 * fig.subplotpars.left) * (row + 2),
      box.height + 2 * fig.subplotpars.bottom,
      transform=fig.transFigure,
      fill=False,
      edgecolor=edgecolor_dict[row],
      linewidth=1 if row == 0 else 2
    ))
  fig.savefig(f'wmc_{cell}.pdf', transparent=True)
  plt.close(fig)

In [3]:
# Reload the AnnData saved after dimension reduction so plotting can start from disk
adata = sc.read_h5ad('220516-ABM.velo.wmc.h5ad')

In [138]:
# Colour by the default 'ann0608' column (writes 'wmc_type.pdf') ...
wmc_results(adata)
# ... and by the 'phase' column (writes 'wmc_phase.pdf')
wmc_results(adata, 'phase', cell='phase')

Plot dendrograms

In [6]:
import os
def plot_dendrograms(dir: str='DendrogramPlots', edgecolor_dict=edgecolor_dict, exts=['png', 'pdf']):
  """Arrange the dendrogram PNGs found in `dir` on a 3 x 4 grid and save the figure.

  File names starting with 'N' are placed in the top-left panel as 'No WMC'.
  For any other name the first character is read as the band count M and the
  text between the first '_' and the following '.' as the sub-band name
  ('low', 'high1', ...). The figure is written to 'dendrograms.<ext>' for each
  entry of `exts` and then closed; nothing is returned.

  Args:
    dir: directory that is scanned for files ending in '.png'.
    edgecolor_dict: maps 0..3 to the colours of the row boxes and the arrows.
    exts: file extensions to save the figure under.
  """
  plot_fnames = [f for f in os.listdir(dir) if f.endswith('.png')]
  fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(15, 12))
  # Start with every panel hidden; only frames and ticks are suppressed, images still show
  for ax in axes.flatten():
    ax.axis('off')
  # Plot dendrograms
  for fname in plot_fnames:
    subplot = ''
    if fname[0] == 'N':
      title = 'No WMC'
      ax = axes[0, 0]
      subplot = 'A'
    else:
      M = int(fname[0])
      title = fname.split('_')[1].split('.')[0].capitalize().replace('High', 'High ')
      # Row M - 2, right-aligned: 'Low' sits in column 4 - M, 'High k' k columns further right
      ax = axes[M - 2, 4 - M + (0 if title == 'Low' else int(title[-1]))]
      # Only the first panel of each band group gets a letter (M = 2 -> 'B', 3 -> 'C', 4 -> 'D')
      if title == 'Low':
         subplot = chr(M + 64)
    if subplot != '':
      ax.annotate(subplot, xy=(0, 1.05), xycoords='axes fraction', size='xx-large', va='top')
    ax.imshow(plt.imread(os.path.join(dir, fname)))
    ax.set_title(title)
  # Add space between rows
  fig.tight_layout(rect=[0, 0, 1, 1.05])
  # Add boxes around rows
  # Here M is a group index: 0 is the 'No WMC' panel, M >= 1 is the (M + 1)-band group
  for M in range(axes.shape[1]):
    if M == 0:
      ax = axes[0, 0]
    else:
      ax = axes[M - 1, 3 - M]
    box = ax.get_position()
    fig.add_artist(plt.Rectangle(
      (box.x0 - fig.subplotpars.left, box.y0 - fig.subplotpars.bottom),
      (box.width + fig.subplotpars.left) * (M + 1),
      box.height + 3 * fig.subplotpars.bottom,
      transform=fig.transFigure,
      fill=False,
      edgecolor=edgecolor_dict[M],
      linewidth=1 if M == 0 else 2
    ))
  # Add arrows from No WMC to WMCs
  # M // 4 is 0 for M in (2, 3) and 1 for M == 4, so the anchors are axes[0, 1] and axes[1, 0]
  for M, xy, xytext in zip([2, 3, 4], [(-0.05, 0.8), (-0.05, 0.5), (0.5, 1.08)], [(0.85, 0.8), (0.5, 0), (0.5, 0)]):
    axes[M // 4, 1 - M // 4].annotate(
      f'{M}-Band',
      xy=xy,
      xytext=xytext,
      xycoords='axes fraction',
      textcoords='axes fraction',
      size='xx-large',
      ha='center',
      va='center',
      arrowprops=dict(arrowstyle="<-", color=edgecolor_dict[M - 1])
    )
  for ext in exts:
    fig.savefig(f'dendrograms.{ext}')
  plt.close(fig)

In [7]:
# Build the dendrogram overview from the PNGs in 'DendrogramPlots' (writes dendrograms.png / .pdf)
plot_dendrograms()

3D UMAP plots

In [3]:
import os
from mpl_toolkits.mplot3d import Axes3D

def plot_order(x):
    """Return the sort key that places panel label `x` on the plot grid.

    'No WMC' sorts first (0). A label containing 'Gap' uses its last
    character as the key. Any other label is expected to start with the band
    count and to end either in 'Low' or in the high-band digit; its key is
    3 * (band count - 1) plus that digit, with 'Low' counting as 0.
    """
    if x == 'No WMC':
        return 0
    if 'Gap' in x:
        return int(x[-1])
    band_offset = 3 * (int(x[0]) - 1)
    # Appending '0' to 'Low' lets the last character serve as the sub-band index
    subband_digit = x.replace('Low', 'Low0')[-1]
    return band_offset + int(subband_digit)

def plot_umap(dir):
    """Plot the 3-D UMAP of every WMC variant stored in `dir`/UMAP.csv.

    The table is read with the 'subtype' and 'celltype' columns renamed to
    'Subtype' and 'Cell Type'. Two placeholder labels ('Gap 1', 'Gap 4') are
    added so that, after sorting with plot_order, some grid cells stay empty.
    Each remaining label gets one 3-D panel with one scatter series per cell
    type. The figure is saved to f'UMAP 3D {dir}.pdf' and closed.
    """
    csv_path = os.path.join(dir, 'UMAP.csv')
    umap_df = pd.read_csv(csv_path).rename(
        columns={'subtype': 'Subtype', 'celltype': 'Cell Type'}
    )

    # Placeholder rows whose only purpose is to add the gap labels to the 'WMC' column
    gap_labels = [f'Gap {n}' for n in [1, 4]]
    gap_rows = umap_df.iloc[:len(gap_labels), :].assign(WMC = gap_labels)
    umap_df = pd.concat([umap_df, gap_rows])

    panel_labels = sorted(np.unique(umap_df['WMC']), key=plot_order)
    cell_types = sorted(np.unique(umap_df['Cell Type']))

    # Column count is the leading digit of the last (highest-sorted) label
    fig, grid = plt.subplots(nrows=3, ncols=int(panel_labels[-1][0]), figsize=(15, 12),
                             subplot_kw={'projection': '3d'})
    panels = grid.flatten()
    for idx, panel_label in enumerate(panel_labels):
        panel = panels[idx]
        if 'Gap' in panel_label:
            panel.axis('off')
            continue
        panel.set_title(panel_label, fontsize='xx-large')
        panel.set_xlabel('UMAP 1')
        panel.set_ylabel('UMAP 2')
        panel.set_zlabel('UMAP 3')
        in_panel = umap_df['WMC'] == panel_label
        for cell_type in cell_types:
            of_type = umap_df['Cell Type'] == cell_type
            points = umap_df[in_panel & of_type]
            panel.scatter(points['UMAP_1'], points['UMAP_2'], points['UMAP_3'],
                          label=cell_type)
    fig.legend(labels=cell_types, loc='center left', fontsize='xx-large')
    fig.tight_layout(h_pad=5, w_pad=3)
    fig.savefig(f'UMAP 3D {dir}.pdf')
    plt.close(fig)

In [2]:
# Dataset directories; each is read for a 'UMAP.csv' file by the cells below
dirs_umap = ['CID3921', 'CID4463', 'CID4495', 'CID4523']

In [88]:
# The loop variable is not called `dir`: at notebook top level that name would
# replace the builtin dir() for every cell run afterwards.
for dataset_dir in dirs_umap:
    plot_umap(dataset_dir)

In [59]:
from sklearn.cluster import DBSCAN
# For every dataset, WMC variant and cell type, record the size of the largest
# DBSCAN cluster found in the 3-D UMAP coordinates.
clusters = []
# The loop variable is not called `dir`, which would shadow the builtin for later cells
for dataset_dir in dirs_umap:
    umap_df = pd.read_csv(
        os.path.join(dataset_dir, 'UMAP.csv')
    ).rename(columns={'subtype': 'Subtype', 'celltype': 'Cell Type'})
    wmc_list = sorted(np.unique(umap_df['WMC']), key=plot_order)
    celltype_list = sorted(np.unique(umap_df['Cell Type']))
    for wmc in wmc_list:
        wmc_mask = umap_df['WMC'] == wmc
        for celltype in celltype_list:
            celltype_mask = umap_df['Cell Type'] == celltype
            # Select the coordinates by name (the columns plot_umap draws)
            # instead of relying on them being the first three columns
            X = umap_df.loc[wmc_mask & celltype_mask, ['UMAP_1', 'UMAP_2', 'UMAP_3']]
            labels = DBSCAN().fit_predict(X)
            # DBSCAN labels noise points -1; they are not a cluster and must
            # not be reported as the largest one
            _, counts = np.unique(labels[labels != -1], return_counts=True)
            largest = int(counts.max()) if counts.size else 0
            clusters.append({'Dataset': dataset_dir, 'WMC': wmc,
                             'Cell Type': celltype, 'Size': largest})
pd.DataFrame(clusters).to_csv('UMAP 3D Cluster Sizes.csv', index=False)

Leiden ARI calculation

In [259]:
from sklearn.metrics import adjusted_rand_score
# Adjusted Rand index between each annotation column and every Leiden
# clustering stored in adata.obs (columns whose name contains 'leiden')
ARI_list = []
for cell in ['ann0608', 'phase']:
    for leiden_col in adata.obs.columns:
        if 'leiden' not in leiden_col:
            continue
        ari = adjusted_rand_score(adata.obs[cell], adata.obs[leiden_col])
        ARI_list.append({'Cell': 'Phase' if cell == 'phase' else 'Type',
                         'WMC': leiden_col.replace('leiden', ''), 'ARI': ari})
ARI_df = pd.DataFrame(ARI_list)
# Applied in this order: '_2band_high1' -> '2-Band\nHigh 1'
wmc_replacements = {
    '_': '',
    'band': '-Band\n',
    'low': 'Low',
    'high': 'High '
}
# The plain 'leiden' column leaves an empty label: that is the run without WMC
ARI_df['WMC'] = ARI_df['WMC'].replace('', 'No WMC')
for old_string, new_string in wmc_replacements.items():
    # regex=False: literal substitution regardless of the pandas version's default
    ARI_df['WMC'] = ARI_df['WMC'].str.replace(old_string, new_string, regex=False)

In [269]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Grouped bar chart of the Leiden ARI per WMC variant, one bar per annotation
fig, ax = plt.subplots(figsize=(8, 5))
barplot = sns.barplot(x='WMC', y='ARI', hue='Cell', data=ARI_df, ax=ax)
for bar in barplot.patches:
  bar_height = bar.get_height()
  # Bars at or below 0.05 are left unlabelled
  if bar_height <= 0.05:
    continue
  label_x = bar.get_x() + bar.get_width() / 2
  barplot.text(label_x, bar_height, f'  {bar_height:.2f}',
               ha='center', va='bottom')
barplot.set_xlabel('WMC')
barplot.set_ylabel('Leiden ARI')
barplot.legend(title='Cell')
fig.savefig('Leiden ARI.pdf')
plt.close('all')