In [118]:
import pandas as pd
import umap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# import umap.plot

import SEACells
import h5py
import collections
import scipy.sparse as sp_sparse
from scipy.sparse import csr_matrix
import scanpy as sc
import anndata

In [161]:
dataset_dir = '../data/'
save_dir = '../data/NGJ/'
dataset = "NGJ_experiment_2_filtered_feature_bc_matrix.h5"
switch_dataset = "switch_genelist.csv"
gene_length_dataset = "gene_len.csv"

# Read switches

In [120]:
switch_df = pd.read_csv(dataset_dir + switch_dataset, index_col=0)
# drop first column
switch_df = switch_df.drop(switch_df.columns[0], axis=1)

# switch df to dict
switch_dict = {}
for row in switch_df.index:
	switch_dict[row] = switch_df.loc[row].dropna().tolist()

# Read Reference Matrix

In [121]:
adata = sc.read_10x_h5(dataset_dir + dataset)
adata

  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 1905 × 36601
    var: 'gene_ids', 'feature_types', 'genome'

In [122]:
# make var names unique
adata.var_names_make_unique()

In [123]:
raw_ad = sc.AnnData(adata.X)
raw_ad.obs_names, raw_ad.var_names = adata.obs_names, adata.var_names
adata.raw = raw_ad

In [124]:
# compute highly variable cells
sc.pp.normalize_per_cell(adata, copy=True)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=3660, subset=False) # n_top_genes = 36601 * 0.1
highly_variable_genes = adata.var['highly_variable']

In [125]:
# reset raw data
adata.X = adata.raw.X[:, :]

In [126]:
# load gene length data
gene_length = pd.read_csv(dataset_dir + gene_length_dataset).T
gene_length.columns = ['gene_length']
gene_length = gene_length.drop(gene_length.index[0])
gene_length.head()

Unnamed: 0,gene_length
ENSG00000121410,8315
ENSG00000148584,86267
ENSG00000175899,48566
ENSG00000166535,64381
ENSG00000184389,14333


In [127]:
# Extract gene expression matrix
gene_expr_matrix = adata.X
# Convert the gene expression matrix to a pandas DataFrame
gene_expr_df = pd.DataFrame.sparse.from_spmatrix(gene_expr_matrix, index=adata.obs_names, columns=adata.var['gene_ids']).T

# add gene name column
gene_expr_df['gene_name'] = adata.var_names

# Ensure both index names match for merging
gene_expr_df.index.name = 'gene_ids'
gene_length.index.name = 'gene_ids'

# Merge gene expression with gene lengths based on gene IDs
merged_data = pd.merge(gene_expr_df, gene_length, left_index=True, right_index=True)
merged_data.head()

Unnamed: 0_level_0,AAACCCAAGCAGCACA-1,AAACGAAAGACCGTTT-1,AAACGAAAGGGCGAAG-1,AAACGCTTCCTCGCAT-1,AAAGAACAGGGAGGAC-1,AAAGGATCATACCATG-1,AAAGGATTCTCGACGG-1,AAAGGGCTCAGCTAGT-1,AAAGGTACAGGTTCAT-1,AAAGGTAGTCCCAAAT-1,...,TTTGATCGTTATCTGG-1,TTTGGAGCAAATGATG-1,TTTGGAGGTATGTCAC-1,TTTGGAGGTGCCTGAC-1,TTTGGAGTCAAATGCC-1,TTTGGTTGTCACGTGC-1,TTTGTTGGTGACACAG-1,TTTGTTGTCAAATGAG-1,gene_name,gene_length
gene_ids,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000186092,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,OR4F5,6167
ENSG00000284733,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,OR4F29,939
ENSG00000284662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,OR4F16,939
ENSG00000187634,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,SAMD11,20653
ENSG00000188976,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,NOC2L,15107


In [128]:
# Calculate RPKM + 1
total_reads = merged_data.iloc[:, :-2].sum(axis=0)
rpkm = (merged_data.iloc[:, :-2].div(merged_data['gene_length'], axis=0) * 1e9).div(total_reads, axis=1) + 1

In [129]:
rpkm.index = merged_data['gene_name']
rpkm.head()

Unnamed: 0_level_0,AAACCCAAGCAGCACA-1,AAACGAAAGACCGTTT-1,AAACGAAAGGGCGAAG-1,AAACGCTTCCTCGCAT-1,AAAGAACAGGGAGGAC-1,AAAGGATCATACCATG-1,AAAGGATTCTCGACGG-1,AAAGGGCTCAGCTAGT-1,AAAGGTACAGGTTCAT-1,AAAGGTAGTCCCAAAT-1,...,TTTGATCCAATGCTCA-1,TTTGATCCATGACAAA-1,TTTGATCGTTATCTGG-1,TTTGGAGCAAATGATG-1,TTTGGAGGTATGTCAC-1,TTTGGAGGTGCCTGAC-1,TTTGGAGTCAAATGCC-1,TTTGGTTGTCACGTGC-1,TTTGTTGGTGACACAG-1,TTTGTTGTCAAATGAG-1
gene_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OR4F5,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
OR4F29,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
OR4F16,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
SAMD11,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
NOC2L,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,8.826257,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [130]:
# save RPKM for future reference
rpkm.to_csv(save_dir + 'rpkm.csv')

  rpkm.to_csv(dataset_dir + 'rpkm.csv')


# Clustering pipeline

## Create anndata for genes in JS

In [132]:
REPRESENT = "gene"

In [133]:
# filter out highly variable genes
rpkm_filtered = rpkm.loc[highly_variable_genes]
rpkm_filtered.head()

Unnamed: 0_level_0,AAACCCAAGCAGCACA-1,AAACGAAAGACCGTTT-1,AAACGAAAGGGCGAAG-1,AAACGCTTCCTCGCAT-1,AAAGAACAGGGAGGAC-1,AAAGGATCATACCATG-1,AAAGGATTCTCGACGG-1,AAAGGGCTCAGCTAGT-1,AAAGGTACAGGTTCAT-1,AAAGGTAGTCCCAAAT-1,...,TTTGATCCAATGCTCA-1,TTTGATCCATGACAAA-1,TTTGATCGTTATCTGG-1,TTTGGAGCAAATGATG-1,TTTGGAGGTATGTCAC-1,TTTGGAGGTGCCTGAC-1,TTTGGAGTCAAATGCC-1,TTTGGTTGTCACGTGC-1,TTTGTTGGTGACACAG-1,TTTGTTGTCAAATGAG-1
gene_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SAMD11,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
HES4,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
ISG15,15.053525,1.0,1.0,1.0,1.0,21.762533,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,36.675882,1.0,1.0,15.52973,1.0,1.0
AGRN,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
RNF223,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [134]:
# take log2 of RPKM+1
rpkm_log2 = np.log2(rpkm_filtered)

In [135]:
# calculate reference genes
rab7a = rpkm.loc['RAB7A'].to_numpy()
gapdh = rpkm.loc['GAPDH'].to_numpy()
actp = rpkm.loc['ACTB'].to_numpy()

In [136]:
ref_norm = np.multiply(rab7a, gapdh, actp) ** (1/3)
ref_norm = np.log2(ref_norm)
ref_norm.min(), ref_norm.max()

(0.0, 4.893892346712768)

In [137]:
# calculate jane score:
# 10 + log2(RPKM) - log2(geomean(RAB7A+1, GAPDH+1, ACTB+1)) if RPKM > 0
# 0 if RPKM == 0
adata_processed_X = np.where(rpkm_log2 == 0, rpkm_log2, 10 + rpkm_log2 - ref_norm)

In [178]:
# save adata_processed_X as csv
pd.DataFrame(adata_processed_X, columns=rpkm_filtered.columns, index=rpkm_filtered.index).to_csv(save_dir + "js_gene.csv")

In [138]:
adata_proceesed = anndata.AnnData(csr_matrix(adata_processed_X.T, dtype=np.float32))
adata_proceesed.X

<1905x3314 sparse matrix of type '<class 'numpy.float32'>'
	with 458419 stored elements in Compressed Sparse Row format>

In [139]:
adata_proceesed.obs_names = rpkm_filtered.columns
adata_proceesed.var_names = rpkm_filtered.index

In [None]:
PCA_ncomps = 720

## Create anndata for switches in JS

In [162]:
REPRESENT = "switch"

In [163]:
# calculate reference genes
rab7a = rpkm.loc['RAB7A'].to_numpy()
gapdh = rpkm.loc['GAPDH'].to_numpy()
actp = rpkm.loc['ACTB'].to_numpy()

In [164]:
ref_norm = np.multiply(rab7a, gapdh, actp) ** (1/3)
ref_norm = np.log2(ref_norm)
ref_norm.min(), ref_norm.max()

(0.0, 4.893892346712768)

In [169]:
# take log2 of RPKM+1
rpkm_log2 = np.log2(rpkm)

In [173]:
# calculate jane score:
# 10 + log2(RPKM) - log2(geomean(RAB7A+1, GAPDH+1, ACTB+1)) if RPKM > 0
# 0 if RPKM == 0
js = pd.DataFrame(np.where(rpkm_log2 == 0, rpkm_log2, 10 + rpkm_log2 - ref_norm), columns=rpkm_log2.columns, index = rpkm_log2.index)

In [166]:
# compare genes in switches with genes in rpkm
switch_genes = [gene for switch in switch_dict.keys() for gene in switch_dict[switch]]
# change to set
switch_genes = set(switch_genes)
rpkm_genes = set(rpkm.index)

# find overlapping genes
overlapping_genes = switch_genes.intersection(rpkm_genes)
print(len(switch_genes), len(rpkm_genes), len(overlapping_genes))

5079 19212 4982


In [167]:
# filter non-overlapping genes from switches
for switch in switch_dict.keys():
	switch_dict[switch] = list(set(switch_dict[switch]).intersection(overlapping_genes))

In [174]:
# aggregate rpkm columns using switch_dict
js_switch = pd.DataFrame(columns=js.columns, index=switch_dict.keys())
for i in switch_dict.keys():
	js_switch.loc[i] = np.mean(js.loc[switch_dict[i]].values, axis=0)

In [176]:
# saving js switch data
js_switch.to_csv(save_dir + "js_switch.csv")

In [181]:
# convert data to AnnData
adata_proceesed = anndata.AnnData(csr_matrix(js_switch.T.values, dtype=np.float32))
adata_proceesed.X

<1905x489 sparse matrix of type '<class 'numpy.float32'>'
	with 492719 stored elements in Compressed Sparse Row format>

In [182]:
adata_proceesed.obs_names = js_switch.columns
adata_proceesed.var_names = js_switch.index

In [222]:
PCA_ncomps = 111

## SEACells Clustering and UMAP

In [223]:
sc.tl.pca(adata_proceesed, n_comps=PCA_ncomps)

In [224]:
cumsum = np.cumsum(adata_proceesed.uns['pca']['variance_ratio'])
print("#", np.count_nonzero(cumsum < 0.95), "PC's capture 95% of variance")

# 111 PC's capture 95% of variance


In [225]:
sc.pp.neighbors(adata_proceesed)
sc.tl.umap(adata_proceesed)
sc.pl.umap(adata_proceesed, save=f'_{REPRESENT}.png')



  cax = scatter(
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
  pl.show()


### Run SEACells

In [226]:
## Core parameters
n_SEACells = 15
build_kernel_on = 'X_pca'

## Additional parameters
n_waypoint_eigs = 10 # Number of eigenvalues to consider when initializing metacells !!not sure how to tune

In [227]:
model = SEACells.core.SEACells(adata_proceesed,
                  build_kernel_on=build_kernel_on,
                  n_SEACells=n_SEACells,
                  n_waypoint_eigs=n_waypoint_eigs,
                  convergence_epsilon = 1e-5)

Welcome to SEACells!


In [228]:
model.construct_kernel_matrix()
M = model.kernel_matrix

Computing kNN graph using scanpy NN ...
Computing radius for adaptive bandwidth kernel...


  0%|          | 0/1905 [00:00<?, ?it/s]

Making graph symmetric...
Parameter graph_construction = union being used to build KNN graph...
Computing RBF kernel...


  0%|          | 0/1905 [00:00<?, ?it/s]

Building similarity LIL matrix...


  0%|          | 0/1905 [00:00<?, ?it/s]

Constructing CSR matrix...


In [229]:
# Initialize archetypes
model.initialize_archetypes()

Building kernel on X_pca
Computing diffusion components from X_pca for waypoint initialization ... 
Determing nearest neighbor graph...
Done.
Sampling waypoints ...
Done.
Selecting 9 cells from waypoint initialization.
Initializing residual matrix using greedy column selection
Initializing f and g...


100%|██████████████████████████████████████████| 16/16 [00:00<00:00, 618.12it/s]

Selecting 6 cells from greedy initialization.





In [230]:
# Plot the initilization to ensure they are spread across phenotypic space
SEACells.plot.plot_initialization(adata_proceesed, model)

  plt.show()


In [231]:
model.fit(min_iter=10, max_iter=100)

Randomly initialized A matrix.
Setting convergence threshold at 0.00076
Starting iteration 1.
Completed iteration 1.
Starting iteration 10.
Completed iteration 10.
Converged after 18 iterations.


### Assessing results

In [232]:
# Check for convergence
model.plot_convergence()

  plt.show()


In [233]:
model.get_hard_assignments().head()

Unnamed: 0_level_0,SEACell
index,Unnamed: 1_level_1
AAACCCAAGCAGCACA-1,SEACell-7
AAACGAAAGACCGTTT-1,SEACell-9
AAACGAAAGGGCGAAG-1,SEACell-10
AAACGCTTCCTCGCAT-1,SEACell-5
AAAGAACAGGGAGGAC-1,SEACell-1


In [234]:
sc.pl.umap(adata_proceesed, color='SEACell', save=f'_SEACells_{REPRESENT}.png')



  cax = scatter(
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not found.
findfont: Font family 'Bitstream Vera Sans' not 

In [235]:
SEACells.plot.plot_SEACell_sizes(adata_proceesed, bins=10)


`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(label_df.groupby('SEACell').count().iloc[:, 0], bins=bins)
  plt.show()


Unnamed: 0_level_0,size
SEACell,Unnamed: 1_level_1
SEACell-0,427
SEACell-1,176
SEACell-2,57
SEACell-3,102
SEACell-4,24
SEACell-5,124
SEACell-6,113
SEACell-7,124
SEACell-8,262
SEACell-9,124


# Saving Results

In [236]:
model.get_hard_assignments().to_csv(save_dir + dataset + f'_SEACell_assignments_{REPRESENT}.csv', sep='\t')