# Process Velmeshev et al.

## Set up Env

In [None]:
import os
import pandas as pd
import numpy as np
import scanpy as sc
from scipy.sparse import csr_matrix

In [None]:
import liana as li

In [None]:
from prep_utils import filter_samples, filter_celltypes, map_gene_symbols

In [None]:
# --- Dataset identity -------------------------------------------------------
dataset = 'velmeshev'  # used to build input/output file names

# --- Column names in adata.obs ----------------------------------------------
sample_key = 'individual'
condition_key = 'diagnosis'
batch_key = 'sex'
groupby = 'cluster'  # cell-type grouping used for filtering and for LIANA

# --- Sample-level filtering thresholds (forwarded to filter_samples) --------
min_cells_per_sample = 700
sample_zcounts_min = -2
sample_zcounts_max = 3

# --- Cell-type-level filtering thresholds (forwarded to filter_celltypes) ---
min_cells = 10  # min number of cells per cell type
min_samples = 5  # min number of samples that pass the threshold per cell type

## Preprocess

### Load data

In [None]:
# Read the dataset's AnnData object from data/<dataset>.h5ad and show its summary.
h5ad_path = os.path.join('data', f"{dataset}.h5ad")
adata = sc.read_h5ad(h5ad_path)
adata

In [None]:
#TODO as param map_var that accepts a csv path
# The variable index is split on '|' into an Ensembl ID (first field) and a
# gene symbol (second field).
raw_var_names = adata.var.reset_index()['index']
df = raw_var_names.str.split('\\|', expand=True)
df = df.rename(columns={0:'ensembl', 1:'genesymbol'})

# Replace adata.var with a table indexed by Ensembl ID; the symbol stays as a column.
adata.var = df.set_index('ensembl')

# Same table, with the column names used for the alias -> gene mapping.
map_df = df.rename(columns={'ensembl':'alias', 'genesymbol':'gene'})


In [None]:
# Persist the Ensembl -> symbol mapping table (without the row index).
map_csv_path = os.path.join('data', "ensembl_to_symbol.csv")
map_df.to_csv(map_csv_path, index=False)

In [None]:
# Number of distinct samples and distinct conditions among the unique
# (sample, condition) pairs found in adata.obs.
sample_condition_pairs = adata.obs[[sample_key, condition_key]].drop_duplicates()
sample_condition_pairs.nunique()

### Convert to Gene Symbols

In [None]:
# Build the Ensembl -> symbol table and map the variables to gene symbols.
#
# This cell must work whether or not the 'ensembl|genesymbol' variable names
# were already split earlier in the notebook. Re-splitting unconditionally
# fails once adata.var has been re-indexed by 'ensembl': reset_index() then no
# longer yields an 'index' column, and the '|'-joined names are gone.
var_names = pd.Series(adata.var_names)
if var_names.str.contains('|', regex=False).any():
    # Names are still in their raw 'ensembl|genesymbol' form: split them.
    df = var_names.str.split('\\|', expand=True).rename(columns={0:'ensembl', 1:'genesymbol'})
    adata.var = df.set_index('ensembl')
else:
    # Already split: recover the same two-column table from adata.var.
    df = pd.DataFrame({
        'ensembl': var_names.to_numpy(),
        'genesymbol': adata.var['genesymbol'].to_numpy(),
    })

# Column names used for the alias -> gene mapping passed to map_gene_symbols.
map_df = df.rename(columns={'ensembl':'alias', 'genesymbol':'gene'})
adata = map_gene_symbols(adata, map_df)

In [None]:
# Display the AnnData summary after the gene-symbol mapping step.
adata

### Filter samples, cell types, and genes

In [None]:
# Sample-level filtering with the thresholds defined in the setup section.
# Arguments are passed positionally, in the order the helper is called with
# elsewhere in this notebook.
adata = filter_samples(
    adata,
    sample_key,
    condition_key,
    min_cells_per_sample,
    sample_zcounts_max,
    sample_zcounts_min,
)

In [None]:
## ^ TODO: double-check why there are more than 16.

In [None]:
# Cell-type-level filtering using the min_cells / min_samples thresholds.
adata = filter_celltypes(
    adata=adata,
    groupby=groupby,
    sample_key=sample_key,
    min_cells=min_cells,
    min_samples=min_samples,
)

In [None]:
# Remove genes expressed in few cells (normalization is done in the next section).
min_cells_per_gene = 30
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)

### Normalize

In [None]:
# Normalize total counts to a target sum of 1e4, then apply log1p.
# Both calls modify `adata` in place and must run in this order.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

## Run LIANA

In [None]:
# Run LIANA's rank_aggregate method sample by sample, grouping cells by `groupby`.
# NOTE(review): n_perms=None presumably skips the permutation step — confirm
# against the liana documentation.
li.mt.rank_aggregate.by_sample(
    adata,
    groupby=groupby,
    use_raw=False,
    sample_key=sample_key,
    verbose=True,
    n_perms=None,
)

In [None]:
# Display the dataset name used to build the output file name below.
dataset

### Add Metadata & Write

In [None]:
# Record which obs columns hold the sample, batch and condition labels so that
# they travel with the saved object.
for uns_name, uns_value in (
    ('sample_key', sample_key),
    ('batch_key', batch_key),
    ('condition_key', condition_key),
):
    adata.uns[uns_name] = uns_value

In [None]:
# Sanity check: all three metadata keys must be present in adata.uns.
required_uns_keys = ('sample_key', 'batch_key', 'condition_key')
present_uns_keys = adata.uns_keys()
assert all(required in present_uns_keys for required in required_uns_keys)

In [None]:
# Write the processed object to data/interim/<dataset>_processed.h5ad.
# Create the output directory first so the write does not fail (after the long
# LIANA run) when data/interim does not exist yet.
interim_dir = os.path.join('data', 'interim')
os.makedirs(interim_dir, exist_ok=True)
adata.write_h5ad(os.path.join(interim_dir, f"{dataset}_processed.h5ad"))

## Classify

In [None]:
import scanpy as sc
from classify_utils import classifier_pipe

In [None]:
# Re-open the processed file written above in read-only backed mode.
processed_path = os.path.join('data', 'interim', f"{dataset}_processed.h5ad")
adata = sc.read_h5ad(processed_path, backed='r')

In [None]:
# Run the classification pipeline from classify_utils on the processed data.
# NOTE(review): presumably this populates adata.uns['auc'] and
# adata.uns['tensor_res'], which are inspected below — confirm in classify_utils.
classifier_pipe(adata, dataset)

In [None]:
# Display the 'auc' entry stored in adata.uns.
adata.uns['auc']

In [None]:
# Display the 'lr_means' entry of the object stored under adata.uns['tensor_res'].
adata.uns['tensor_res'].X['lr_means']