In [None]:
import os
from collections import Counter
from math import hypot

import anndata as ad
import matplotlib.pyplot as plt
import metacells as mc
import numpy as np
import pandas as pd
import scipy.sparse as sp
import seaborn as sb
from IPython.display import set_matplotlib_formats
from matplotlib.collections import LineCollection

# Notebook-wide plotting defaults: plain white seaborn style and SVG inline figures.
# NOTE(review): newer IPython releases deprecate IPython.display.set_matplotlib_formats
# in favour of the matplotlib_inline package — confirm against the installed version.
sb.set_style("white")
set_matplotlib_formats('svg')

In [None]:
# Read the input AnnData object (path is relative to the notebook's working directory).
raw_cells_path = 'output/scrna_db/mm_cells_with_md.h5ad'
raw = ad.read_h5ad(raw_cells_path)

# Optional pre-filters, currently disabled:
#   * keep only genes (vars) whose total over all cells exceeds 100
# raw = mc.ut.slice(raw, vars=np.squeeze(np.asarray(np.sum(raw.X, axis = 0))) > 100)
#   * drop the cells listed in exe_cells_fixed.txt
# exe_cells = set(i for j in pd.read_table("../rabemb/exe_cells_fixed.txt", header=None).values for i in j)
# raw = mc.ut.slice(raw, obs=[i not in exe_cells for i in raw.obs_names])

In [None]:
# Label the dataset 'embexe' and report its size as (observations, variables).
mc.ut.set_name(raw, 'embexe')
print((raw.n_obs, raw.n_vars))

In [None]:
# Genes excluded from the analysis by exact name.
excluded_gene_names = [
    "Neat1", "Xist", "Malat1",
    "AK140265", "AK018753", "AK163440", "DQ539915", "AK131586", "AK131579",
    "AK142750", "X57780", "GU332589", "BC071253",
]

# Genes excluded by name pattern; only 'ERCC*' is currently active.
# NOTE(review): if these are regular expressions (not globs), 'ERCC*' means
# "ERC" followed by any number of "C" — confirm the intended semantics.
excluded_gene_patterns = ['ERCC*']
# Disabled patterns, kept for reference:
#     '^IGJ', '^IGH', '^IGK', '^IGL', 'MT-*', "^MTMR*", '^MTRNR*', '^MTND*',
#     'hotspot*',
#     'LOC*'

In [None]:
%%time
# Analyze which genes should be excluded, using the name list and patterns
# defined above; the fixed seed keeps the run reproducible.
mc.pl.analyze_clean_genes(
    raw,
    excluded_gene_names=excluded_gene_names,
    excluded_gene_patterns=excluded_gene_patterns,
    random_seed=123456,
)

In [None]:
%%time
# Presumably combines the per-gene results of analyze_clean_genes into the
# `clean_gene` mask in raw.var that the following cells read — confirm in the
# metacells documentation.
mc.pl.pick_clean_genes(raw)

In [None]:
# Tally how many genes carry each value of the `clean_gene` mask
# (displayed as the cell's output).
Counter(raw.var.clean_gene)

In [None]:
%%time
# Optional checkpoint of the annotated raw data (disabled):
# raw.write('full_embexe.h5ad')
# NOTE: `full` is just another name for `raw` (no copy is made), so any
# in-place change made through `full` is also visible through `raw`.
full = raw

In [None]:
%%time
# Bounds on the total UMIs of a cell; cells outside [min, max] are reported
# below as candidates for exclusion. Both values are reused by later cells.
properly_sampled_min_cell_total = 2400
properly_sampled_max_cell_total = 32000

# Per-cell sum of the main data matrix ('__x__'), as a NumPy vector.
total_umis_of_cells = mc.ut.get_o_numpy(full, name='__x__', sum=True)

# Distribution of per-cell totals with both thresholds marked.
# NOTE(review): sb.distplot is deprecated in recent seaborn releases; it is
# kept here so the figure does not change.
plot = sb.distplot(total_umis_of_cells)
plot.set(xlabel='UMIs', ylabel='Density', yticks=[])
plot.axvline(x=properly_sampled_min_cell_total, color='darkgreen')
plot.axvline(x=properly_sampled_max_cell_total, color='crimson')

# Vectorized counts (the builtin sum() would iterate the array in Python).
cells_count = len(total_umis_of_cells)
too_small_cells_count = np.count_nonzero(total_umis_of_cells < properly_sampled_min_cell_total)
too_large_cells_count = np.count_nonzero(total_umis_of_cells > properly_sampled_max_cell_total)

too_small_cells_percent = 100.0 * too_small_cells_count / cells_count
too_large_cells_percent = 100.0 * too_large_cells_count / cells_count

print(f"Will exclude {too_small_cells_count} ({too_small_cells_percent:.2f}%) "
      f"cells with less than {properly_sampled_min_cell_total} UMIs")
print(f"Will exclude {too_large_cells_count} ({too_large_cells_percent:.2f}%) "
      f"cells with more than {properly_sampled_max_cell_total} UMIs")

In [None]:
# Median of the per-cell UMI totals (displayed as the cell's output).
np.median(total_umis_of_cells)

In [None]:
%%time
# Maximal fraction of a cell's UMIs that may come from excluded (non-clean)
# genes; reused by mc.pl.analyze_clean_cells in a later cell.
properly_sampled_max_excluded_genes_fraction = 0.01

# Restrict the data to the genes that are NOT clean ('~clean_gene'), then
# compare each cell's UMIs in those genes against its overall total.
excluded_genes_data = mc.tl.filter_data(full, var_masks=['~clean_gene'])[0]
excluded_umis_of_cells = mc.ut.get_o_numpy(excluded_genes_data, name='__x__', sum=True)
excluded_fraction_of_umis_of_cells = excluded_umis_of_cells / total_umis_of_cells

# NOTE(review): sb.distplot is deprecated in recent seaborn releases; it is
# kept here so the figure does not change.
plot = sb.distplot(excluded_fraction_of_umis_of_cells)
plot.set(xlabel='Fraction of excluded gene UMIs', ylabel='Density', yticks=[])
plot.axvline(x=properly_sampled_max_excluded_genes_fraction, color='crimson')

# Vectorized count (the builtin sum() would iterate the array in Python).
too_excluded_cells_count = np.count_nonzero(
    excluded_fraction_of_umis_of_cells > properly_sampled_max_excluded_genes_fraction)

too_excluded_cells_percent = 100.0 * too_excluded_cells_count / len(total_umis_of_cells)

print(f"Will exclude {too_excluded_cells_count} ({too_excluded_cells_percent:.2f}%) "
      f"cells with more than {100.0 * properly_sampled_max_excluded_genes_fraction:.2f}% "
      f"excluded gene UMIs")

In [None]:
%%time
# Apply the three thresholds chosen in the previous cells.
clean_cell_thresholds = {
    'properly_sampled_min_cell_total': properly_sampled_min_cell_total,
    'properly_sampled_max_cell_total': properly_sampled_max_cell_total,
    'properly_sampled_max_excluded_genes_fraction': properly_sampled_max_excluded_genes_fraction,
}
mc.pl.analyze_clean_cells(full, **clean_cell_thresholds)

In [None]:
%%time
# Presumably turns the per-cell results of analyze_clean_cells into a final
# clean-cell mask on `full` — confirm in the metacells documentation.
mc.pl.pick_clean_cells(full)

In [None]:
%%time
# `clean` is the dataset every later step works on; presumably it holds only
# the clean genes and clean cells of `full` — confirm in the metacells docs.
clean = mc.pl.extract_clean_data(full)

In [None]:
%%time
# Candidate "suspect" genes, listed by exact name. They are resolved against
# the genes actually present in `clean` by find_named_genes below.
#
# FIX: a comma was missing after "Hspb8", so Python's implicit string-literal
# concatenation merged it with the next literal into "Hspb8Hbb-y" and neither
# "Hspb8" nor "Hbb-y" was actually in the list.
suspect_gene_names = [
    "A2m", "AA465934;AI450353", "AK021383;Prrc2c", "AK033756;Rab3il1", "AK087340;Eif3a", "AK156288;Tpd52", "AK158346;Snrpd2", "AK164737;Ell2", "AK165270;Rbm25", "AK196308;Tuba1b", "AK202516;P4hb", "AK204572;Eef1a1", "Acsl3;Utp14b", "Alad", "Aldoa", "Ankrd11", "Arl6ip1", "Atp5j", "Atrx", "Bst2", "Calr", "Cap1", "Cbx7", "Ccdc155", "Ccnb1", "Ccne1", "Cenpe", "Cenpf", "Chchd10", "Chd4;Mir7045", "Cox7b", "Cox7c", "Cox8a", "Cpox", "Csf2rb;Mir7676-2", "Ddx21", "Dek", "Dmkn", "Dsg2", "Eif5a", "Eif5b", "Erv3", "F11r", "Fblim1", "Fech", "Glrx5", "Gm12338", "Gm15772;Rpl26", "Gm1821", "Gprc5a", "Gpx1", "Gse1", "Hbb-b2", "Hdac6", "Hdgf", "Hist1h1a", "Hist1h1b", "Hist1h1c", "Hist1h1d", "Hist1h1e", "Hist1h2ae", "Hmmr", "Hsp90aa1", "Hsp90ab1", "Hsp90b1", "Hspa5", "Hspa8", "Hspe1", "Il1r2", "Isyna1", "Kif20b", "Kmt2a", "Ldha", "Ly6a", "Ly6c1", "Manf", "Mbnl1", "Mir6236", "Mir7079;Rpl13", "Mki67", "Msh6", "Naca", "Nasp", "Ncl", "Npm1", "Pdia6", "Pfn1", "Pkm", "Plekhf2", "Pnpo", "Prc1", "Prdx2", "Prmt1", "Prpf40a", "Psip1", "Psmb10", "Ptma", "Pttg1", "Rab15", "Rell1", "Rn45s", "Rpl14-ps1", "Rpl22l1", "Rpl23", "Rpl32", "Rpl37", "Rpl38", "Rpl41", "Rpl7", "Rplp0", "Rplp1", "Rps10", "Rps11", "Rps14", "Rps15", "Rps15a-ps6", "Rps18", "Rps2", "Rps20", "Rps21", "Rps25", "Rps26", "Rps27l", "Rps4l", "Rps5", "Rpsa", "Rrm2", "S100a4", "S100a8", "Slc14a1", "Slc16a3", "Slc6a12", "Smc4", "Smox", "Soat1", "Son", "Spint1", "Spns2", "Srrm2",
#     "Ssx2ip",
    "Tac2", "Tfrc", "Tjp2", "Tmem14c", "Top1", "Top2a", "Tpr", "Tubb5", "Ubb", "Ube2c", "Ung", "Hspb1", "Hspb8",

    # Hemoglobin genes; the trailing numbers are the original author's notes
    # (their meaning is not stated in this notebook).
#     Hbb-bs, # 1e-3
    "Hbb-y",  # 2e-2
#     Hbb-b2, # 2e-4
    "Hba-a2", # 1e-2
    "Hba-x",  # 4e-2
    "Hbb-bh1",# 1e-1
]

# Additional suspects selected by name pattern.
# NOTE(review): these patterns are upper-case ('MCM[0-9]', 'COX.+', ...) while
# the gene names above are mixed-case ('Smc4', 'Cox7b'); if pattern matching is
# case-sensitive they may match nothing in this dataset — confirm against
# metacells' find_named_genes. Also, 'Hist1*.' read as a regular expression
# means "Hist", any number of "1", then one character — confirm the intent.
suspect_gene_patterns = [ 'MCM[0-9]', 'SMC[0-9]', 'IFI.*', 'Hist1*.', 'HSP90.+' , 'COX.+']
# suspect_gene_patterns = ['MCM[0-9]', 'SMC[0-9]', 'IFI.*', 'Hist1*.', 'HSP*' , 'COX.+']

# Boolean per-gene mask of the suspects found in `clean`.
suspect_genes_mask = mc.tl.find_named_genes(clean, names=suspect_gene_names,
                                            patterns=suspect_gene_patterns)
# From here on `suspect_gene_names` holds the sorted names of the genes that
# were actually matched, not the candidate list above.
suspect_gene_names = sorted(clean.var_names[suspect_genes_mask])

In [None]:
%%time
# Compute gene-gene similarity and gene modules on `clean`; the results are
# read back in the following cells as 'related_genes_module' and
# 'related_genes_similarity'.
# The original note lists the accepted similarity methods as:
# "pearson", "repeated_pearson", "logistics", "logistics_pearson".
mc.pl.relate_genes(
    clean,
    genes_similarity_method="pearson",
    random_seed=123456,
)  # min_genes_of_modules is not passed

In [None]:
# Module index of every gene, as stored in clean.var by relate_genes.
module_of_genes = clean.var['related_genes_module']

# Distinct modules containing at least one suspect gene; negative module
# indices are dropped.
modules_with_suspects = np.unique(module_of_genes[suspect_genes_mask])
suspect_gene_modules = modules_with_suspects[modules_with_suspects >= 0]
print(suspect_gene_modules)

In [None]:
# One heatmap per suspect gene module, showing the pairwise similarity of the
# module's genes. Suspect genes are prefixed with '(*) ' in the tick labels.
similarity_of_genes = mc.ut.get_vv_frame(clean, 'related_genes_similarity')
starred_genes = set(suspect_gene_names)

for gene_module in suspect_gene_modules:
    module_genes_mask = module_of_genes == gene_module
    similarity_of_module = similarity_of_genes.loc[module_genes_mask, module_genes_mask]

    # Same labels on both axes (the sub-matrix is square).
    labels = [
        ('(*) ' + gene) if gene in starred_genes else gene
        for gene in similarity_of_module.index
    ]
    similarity_of_module.index = labels
    similarity_of_module.columns = labels

    ax = plt.axes()
    sb.heatmap(
        similarity_of_module,
        vmin=0,
        vmax=1,
        xticklabels=True,
        yticklabels=True,
        ax=ax,
        cmap="YlGnBu",
    )
    ax.set_title(f'Gene Module {gene_module}')
    # Optionally save each figure:
#     plt.savefig('embexe_new_{}.png'.format(gene_module), bbox_inches='tight')
    plt.show()

# NOTE: the forbidden gene modules listed in the next cell must be selected manually by inspecting the heatmaps above.

In [None]:
%%time
# Gene modules chosen by hand from the per-module similarity heatmaps.
forbidden_gene_modules = [14, 37, 60, 69, 89, 105, 122, 123, 127]

# Start from a COPY of the suspect mask. Binding the same object and then
# using `|=` would modify `suspect_genes_mask` in place, so re-running earlier
# cells that read it would silently give different results.
forbidden_genes_mask = suspect_genes_mask.copy()

# Add every gene that belongs to one of the hand-picked modules.
for gene_module in forbidden_gene_modules:
    module_genes_mask = module_of_genes == gene_module
    forbidden_genes_mask |= module_genes_mask
# forbidden_genes_mask['HES1'] = False

forbidden_gene_names = sorted(clean.var_names[forbidden_genes_mask])
print(len(forbidden_gene_names))
print(' '.join(forbidden_gene_names))

In [None]:
# Let metacells guess a parallelism limit from `clean`, show it, then apply it.
# NOTE(review): set_max_parallel_piles takes only the number, so it looks like
# a library-wide setting rather than something stored on `clean` — confirm.
max_parallel_piles = mc.pl.guess_max_parallel_piles(clean)
print(max_parallel_piles)
mc.pl.set_max_parallel_piles(max_parallel_piles)

In [None]:
%%time
# Run the metacells pipeline on `clean`, passing the forbidden gene names
# selected above. The fixed seed keeps the run reproducible.
mc.pl.divide_and_conquer_pipeline(
    clean,
    forbidden_gene_names=forbidden_gene_names,
    target_metacell_size=320000,
    random_seed=123456,
)

In [None]:
%%time
# Build the metacell-level dataset (named 'metacells') from the per-cell
# assignments the pipeline stored on `clean`.
metacells = mc.pl.collect_metacells(clean, name='metacells')

In [None]:
# Display the observation names of the metacell dataset.
metacells.obs_names

In [None]:
# Keep only the cells with a non-negative metacell index (negative values
# presumably mark cells that were not assigned to any metacell — confirm).
# The stray bare `clean.obs.metacell >= 0` expression that used to precede
# this line was neither displayed nor stored, so it has been removed.
clean2 = mc.ut.slice(clean, obs=clean.obs.metacell >= 0)

In [None]:
# Size of the smallest metacell: count the cells of each metacell (via the
# 'properly_sampled_cell' column) and take the minimum.
clean2.obs.groupby('metacell')['properly_sampled_cell'].count().min()

In [None]:
%%time
# 2D UMAP layout of the metacells; the coordinates are read back below as
# 'umap_x' / 'umap_y'.
mc.pl.compute_umap_by_features(
    metacells,
    max_top_feature_genes=1000,
    min_dist=2.0,
    random_seed=42,
)

In [None]:
# Fetch the UMAP coordinates of the metacells and scatter-plot them.
umap_x, umap_y = (
    mc.ut.get_o_numpy(metacells, coordinate)
    for coordinate in ('umap_x', 'umap_y')
)
plot = sb.scatterplot(x=umap_x, y=umap_y)

In [None]:
%%time
# Overlay the "long" graph edges on the UMAP scatter plot.
# The outgoing-weights matrix is converted to COO form so it can be walked as
# (source index, target index, weight) triplets.
umap_edges = sp.coo_matrix(mc.ut.get_oo_proper(metacells, 'obs_outgoing_weights'))
min_long_edge_size = 4

sb.set()
plot = sb.scatterplot(x=umap_x, y=umap_y)

for src, dst, weight in zip(umap_edges.row, umap_edges.col, umap_edges.data):
    edge_xs = [umap_x[src], umap_x[dst]]
    edge_ys = [umap_y[src], umap_y[dst]]
    edge_length = hypot(edge_xs[1] - edge_xs[0], edge_ys[1] - edge_ys[0])
    # Skip edges shorter than the threshold (written as `not >=` so that a NaN
    # length is skipped as well).
    if not (edge_length >= min_long_edge_size):
        continue
    # Line width scales with the edge weight.
    plt.plot(edge_xs, edge_ys, linewidth=weight * 2, color='indigo')

plt.show()

In [None]:
# Persist both results: the per-cell data and the metacell-level data.
for dataset, output_path in (
    (clean, 'clean_embexe2.h5ad'),
    (metacells, 'metacells_embexe2.h5ad'),
):
    dataset.write(output_path)