In [None]:
import importlib

import numpy as np
from scipy import sparse as ss
import pandas as pd
import anndata

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from rich import print as rprint
from rich.traceback import install
# Register rich's traceback handler (imported above from rich.traceback).
install()

In [None]:
import pipeline
# Reload the module and rebind the name so that re-running this cell picks up
# edits made to `pipeline` without restarting the kernel.
pipeline = importlib.reload(pipeline)

## Brain

In [None]:
# Directory prefix for the Brain inputs read below ({NAME}.feather,
# {NAME}_Label3.csv, {NAME}_Label34.csv). It is concatenated with file names,
# so the trailing slash matters.
# NOTE(review): absolute, machine-specific path.
DATA_DIR = '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/Brain/cached/'

In [None]:
# Peek at the first line of each combined Brain CSV and print how many
# comma-separated fields it contains.
combined_csv_paths = (
    '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/Brain/MouseALM_HumanMTG/MouseALM_HumanMTG.csv',
    '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/Brain/MouseV1_HumanMTG/MouseV1_HumanMTG.csv',
    '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/Brain/MouseV1_MouseALM/MouseV1_MouseALM.csv',
    '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/Brain/MouseV1_MouseALM_HumanMTG/MouseV1_MouseALM_HumanMTG.csv',
)
for csv_path in combined_csv_paths:
    with open(csv_path) as handle:
        # Only the first line is needed; the rest of the file is never read.
        first_line = handle.readline()
    # readline() returns '' for an empty file, in which case nothing is printed.
    if first_line:
        print(len(first_line.split(',')))

In [None]:
# For each Brain dataset: plot the coarse/refined label distributions, plot two
# per-row QC histograms, then embed the un-filtered data with UMAP and export
# the scatter plots (unlabelled, coarse labels, refined labels).
for NAME in ['HumanMTG', 'MouseALM', 'MouseV1']:
    df = pd.read_feather(DATA_DIR + f'{NAME}.feather')
    # `read_csv(..., squeeze=True)` was deprecated in pandas 1.4 and removed in
    # 2.0; squeezing the single-column frame afterwards is the documented
    # replacement and yields the same Series.
    label_coarse = pd.read_csv(DATA_DIR + f'{NAME}_Label3.csv').squeeze('columns')
    label_refined = pd.read_csv(DATA_DIR + f'{NAME}_Label34.csv').squeeze('columns')
    # ---- plot label dist ----
    f, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 3]}, figsize=(14, 7))
    sns.histplot(data=label_coarse, shrink=.8, ax=ax1)
    sns.histplot(data=label_refined, shrink=.8, ax=ax2)
    # plt.xticks acts on the current axes only.
    # NOTE(review): presumably that is ax2 (the last axes created) - confirm.
    plt.xticks(
        rotation=45,
        horizontalalignment='right',
    )
    plt.savefig(f'./explore/imgs/Brain_{NAME}_label_dist.jpg')
    # ---- plot qc ----
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
    # Left: per-row sum of values; right: per-row count of non-zero entries.
    df.sum(axis=1).plot(kind='hist', bins=30, ax=ax1)
    (df != 0).sum(axis=1).plot(kind='hist', bins=30, ax=ax2)
    plt.savefig(f'./explore/imgs/Brain_{NAME}_qc.jpg')
    # ---- UMAP ----
    X_normalized = pipeline.normalize(df.values, apply_qc=False)
    umap_object, X_umap = pipeline.umap(X_normalized, n_components=2)
    fig = pipeline.plot(X_umap, backend='plotly')
    fig.write_image(f'./explore/imgs/Brain_{NAME}_noqc_umap.jpg')
    fig = pipeline.plot(X_umap, labels=pipeline.label_to_idx(label_coarse)[0], backend='plotly')
    fig.write_image(f'./explore/imgs/Brain_{NAME}_noqc_umap_coarse.jpg')
    fig = pipeline.plot(X_umap, labels=pipeline.label_to_idx(label_refined)[0], backend='plotly')
    fig.write_image(f'./explore/imgs/Brain_{NAME}_noqc_umap_refined.jpg')


## CellBench

In [None]:
# Rebinds DATA_DIR to the CellBench inputs read below ({NAME}.feather and
# label_{NAME}.csv); the trailing slash matters because names are concatenated.
# NOTE(review): absolute, machine-specific path.
DATA_DIR = '/home/tiankang/wusuowei/data/single_cell/MARS/scRNAseq_Benchmark_datasets/Inter-dataset/CellBench/processed/'

In [None]:
# For each CellBench dataset: plot the label distribution, plot two per-row QC
# histograms, then embed the un-filtered data with UMAP and export the scatter
# plots (unlabelled and coloured by label).
for NAME in ('10x', 'celseq2'):
    df = pd.read_feather(DATA_DIR + f'{NAME}.feather')
    # `read_csv(..., squeeze=True)` was deprecated in pandas 1.4 and removed in
    # 2.0; squeezing the single-column frame afterwards is the documented
    # replacement and yields the same Series.
    label = pd.read_csv(DATA_DIR + f'label_{NAME}.csv').squeeze('columns')
    # ---- plot label dist ----
    f, ax = plt.subplots(figsize=(7, 7))
    sns.histplot(data=label, shrink=.8, ax=ax)
    plt.xticks(
        rotation=45,
        horizontalalignment='right',
    )
    plt.savefig(f'./explore/imgs/CellBench_{NAME}_label_dist.jpg')
    # ---- plot qc ----
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
    # Left: per-row sum of values; right: per-row count of non-zero entries.
    df.sum(axis=1).plot(kind='hist', bins=30, ax=ax1)
    (df != 0).sum(axis=1).plot(kind='hist', bins=30, ax=ax2)
    plt.savefig(f'./explore/imgs/CellBench_{NAME}_qc.jpg')
    # ---- UMAP ----
    X_normalized = pipeline.normalize(df.values, apply_qc=False)
    umap_object, X_umap = pipeline.umap(X_normalized, n_components=2)
    fig = pipeline.plot(X_umap, backend='plotly')
    fig.write_image(f'./explore/imgs/CellBench_{NAME}_noqc_umap.jpg')
    fig = pipeline.plot(X_umap, labels=pipeline.label_to_idx(label)[0], backend='plotly')
    fig.write_image(f'./explore/imgs/CellBench_{NAME}_noqc_umap_coarse.jpg')

## Tabula Muris

In [None]:
# Load the Tabula Muris Senis FACS AnnData object from its .h5ad file.
# NOTE(review): absolute, machine-specific path.
tubula_muris_h5ad = anndata.read_h5ad('/home/tiankang/wusuowei/data/single_cell/MARS/tabula-muris-senis-facs_mars.h5ad')

In [None]:
# Display the dimensions of the expression matrix.
tubula_muris_h5ad.X.shape

In [None]:
# Display the `var` annotation table of the AnnData object.
tubula_muris_h5ad.var

In [None]:
# Display the `obs` annotation table of the AnnData object.
tubula_muris_h5ad.obs

In [None]:
# List the distinct values of the reannotated cell-ontology class column.
tubula_muris_h5ad.obs.cell_ontology_class_reannotated.unique()

In [None]:
# Shorthand for the obs table. Note that this rebinds `df`, which the Brain and
# CellBench loops above used for their expression frames.
df = tubula_muris_h5ad.obs

In [None]:
# Count the cells whose reannotated class has more than 200 members, leaving
# out the first class in value_counts order (the most frequent one, since
# value_counts sorts by descending count).
# The comparison has to be used as a mask: `(counts > 200).index` is the index
# of the boolean Series and therefore lists every class, not just the frequent
# ones, so the threshold had no effect before.
class_counts = df.cell_ontology_class_reannotated.value_counts()
frequent_classes = class_counts[class_counts > 200].index.tolist()[1:]
df.cell_ontology_class_reannotated.isin(frequent_classes).sum()

In [None]:
# Sample 15,000 rows of the raw matrix, restricted to cells whose reannotated
# class has more than 200 members (excluding the most frequent class, which is
# first in value_counts order).
# The comparison must be applied as a mask: `(counts > 200).index` lists every
# class, so the threshold was previously ignored.
VALID_SAMPLE_SEED = 0  # fixed so the sampled subset is reproducible across runs
class_counts = df.cell_ontology_class_reannotated.value_counts()
frequent_classes = class_counts[class_counts > 200].index.tolist()[1:]
eligible_rows = np.flatnonzero(
    df.cell_ontology_class_reannotated.isin(frequent_classes).to_numpy()
)
# replace=False raises ValueError if fewer than 15,000 rows are eligible.
valid_rows = np.random.default_rng(VALID_SAMPLE_SEED).choice(
    eligible_rows, size=15000, replace=False
)
valid_data = tubula_muris_h5ad.X[valid_rows]

In [None]:
# Write the sampled validation rows to disk in scipy's .npz sparse format.
# NOTE(review): absolute, machine-specific path; the target directory must exist.
ss.save_npz('/home/tiankang/wusuowei/data/single_cell/MARS/Tabula_Muris/val/count.npz', valid_data)

In [None]:
# Per-cell reannotated class labels taken from the obs table.
# `labels` is rebound further down by the pipeline.normalize call.
labels = df.cell_ontology_class_reannotated

In [None]:
# Draw one histogram per obs column, skipping any column that has more than
# 200 distinct values.
obs_table = tubula_muris_h5ad.obs
for column_name in obs_table.columns:
    rprint(f'---- {column_name} ----')
    column_values = obs_table[column_name]
    if len(column_values.unique()) > 200:
        print(f'{column_name} is skipped.')
    else:
        figure, axis = plt.subplots(figsize=(20, 8))
        sns.histplot(column_values, shrink=0.7, ax=axis)
        plt.xticks(rotation=90, horizontalalignment='right')
        plt.show()

In [None]:
# Reference (not a copy) to the AnnData expression matrix; `X` is rebound to a
# copy a couple of cells below.
X = tubula_muris_h5ad.X

In [None]:
# Rebinds DATA_DIR to the Tabula Muris output directory used by the save calls
# below (count.npz, label.csv, label_ids.txt, train/ and val/ files).
# NOTE(review): absolute, machine-specific path.
DATA_DIR = '/home/tiankang/wusuowei/data/single_cell/MARS/Tabula_Muris/'

In [None]:
# Work on a copy of the expression matrix so the AnnData object itself is left
# untouched, and persist it in scipy's .npz sparse format.
X = tubula_muris_h5ad.X.copy()
ss.save_npz(DATA_DIR + 'count.npz', X)

In [None]:
# Export the full obs table (all annotation columns, index included) as CSV.
tubula_muris_h5ad.obs.to_csv(DATA_DIR + 'label.csv')

In [None]:
# Reload `pipeline` again so the cells below use its latest on-disk version.
pipeline = importlib.reload(pipeline)

In [None]:
# Normalise the copied count matrix together with the reannotated labels.
# The call returns two values: the processed matrix and a `labels` object,
# which replaces the `labels` variable assigned earlier.
# NOTE(review): the gene_*/cell_* arguments look like QC thresholds and
# `plot`/`preqc_path`/`postqc_path` look like they write before/after QC
# figures - confirm against pipeline.normalize, which is not visible here.
X_normalized, labels = pipeline.normalize(
    X,
    tubula_muris_h5ad.obs.cell_ontology_class_reannotated,
    clip_q=99,
    gene_min_cells=50,
    gene_min_counts=100,
    cell_min_genes=100,
    cell_min_counts=1000,
    cell_max_counts=4_000_000,
    logic='mine',
    plot=True,
    preqc_path='./explore/imgs/Tubula_Muris_preqc.jpg',
    postqc_path='./explore/imgs/Tubula_Muris_postqc.jpg',
)

In [None]:
# How many cells belong to a label with more than 200 members, not counting the
# first (most frequent) of those labels.
label_counts = labels.value_counts()
frequent_labels = label_counts.loc[label_counts > 200].index.tolist()
labels.isin(frequent_labels[1:]).sum()

In [None]:
# Total number of cells that belong to labels with at most 200 members.
label_counts = labels.value_counts()
label_counts.loc[label_counts <= 200].sum()

In [None]:
# Pick the 15,000 rows of X_normalized that form the validation split. Only
# cells whose label has more than 200 members are eligible, and the most
# frequent label (first in value_counts order) is excluded.
SPLIT_SEED = 0  # fixed so the train/val split written to disk is reproducible
label_counts = labels.value_counts()  # computed once instead of twice
eligible_labels = label_counts[label_counts > 200].index.tolist()[1:]
eligible_rows = np.flatnonzero(labels.isin(eligible_labels).to_numpy())
# replace=False raises ValueError if fewer than 15,000 rows are eligible.
indices = np.random.default_rng(SPLIT_SEED).choice(
    eligible_rows,
    replace=False,
    size=15000,
)

In [None]:
# Boolean row selector over X_normalized: True for the sampled validation rows,
# False for everything else (the training rows).
n_rows = X_normalized.shape[0]
mask = np.full(n_rows, False, dtype=bool)
mask[indices] = True

In [None]:
# Encode the labels: returns the per-cell ids plus a mapping whose items are
# written out as label/id pairs below.
# NOTE(review): other cells call `pipeline.label_to_idx(...)[0]` instead -
# confirm both helpers exist in the current version of `pipeline`.
ids, labels_to_ids_dict = pipeline.labels_to_ids(labels)

In [None]:
# Dump the label/id mapping as a two-column, tab-separated text file with one
# dictionary entry per line.
with open(DATA_DIR + 'label_ids.txt', 'w') as out_file:
    out_file.writelines(
        f'{key!s}\t{value!s}\n' for key, value in labels_to_ids_dict.items()
    )

In [None]:
# Save the encoded label ids for each split: rows selected by `mask` go to
# val/, the remaining rows go to train/.
np.save(DATA_DIR + 'val/labels.npy', ids[mask])
np.save(DATA_DIR + 'train/labels.npy', ids[~mask])

In [None]:
# Persist the normalised matrix for both splits in scipy's .npz sparse format.
ss.save_npz(DATA_DIR + 'val/data.npz', X_normalized[mask])

ss.save_npz(DATA_DIR + 'train/data.npz', X_normalized[~mask])
# Write the encoded ids, matching val/labels.npy. This line used to save the
# raw `labels` values, which overwrote the id file written just above and left
# train/ and val/ label files in different formats.
np.save(DATA_DIR + 'train/labels.npy', ids[~mask])

In [None]:
# Histogram of `labels`, with the x tick labels rotated vertically.
tick_style = {'rotation': 90, 'horizontalalignment': 'right'}
f, ax = plt.subplots(figsize=(20, 8))
sns.histplot(labels, shrink=0.7, ax=ax)
plt.xticks(**tick_style)
# A trailing None keeps the notebook from echoing the xticks return value.
None

In [None]:
np.save('./explore/cached/Tubula_Muris_umap.npy', X_umap)

In [None]:
umap_object, X_umap = pipeline.umap(X_normalized, n_components=3, n_neighbors=50)

In [None]:
X_umap = np.load('./explore/cached/Tubula_Muris_umap.npy')

In [None]:
# Colour the embedding by cell type. `X_umap` is computed from `X_normalized`,
# so take the `labels` that pipeline.normalize returned alongside it (the same
# pairing the train/val mask relies on). The raw obs column covers every cell
# in the AnnData object and can be longer than the embedding.
# NOTE(review): `label_to_idx` vs. the `labels_to_ids` helper used earlier -
# confirm which name the current `pipeline` module exposes.
fig = pipeline.plot(
    X_umap,
    labels=pipeline.label_to_idx(labels)[0],
    backend='plotly')
fig