In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import torch
import scarches as sca

In [2]:
# Load the published pancreas dataset; .X holds the counts.
H5AD_PATH = 'publ_pancreas.h5ad'
adata = sc.read(H5AD_PATH)

In [3]:
# Number of cells contributed by each study.
adata.obs['study'].value_counts()

NOD_elimination    54329
STZ                49545
embryo             37561
spikein_drug       33331
NOD                 2690
Name: study, dtype: int64

In [4]:
# Hold one study out as the query; every other study forms the reference.
to_query = "spikein_drug"  # arbitrary choice of study, used only as an example
mask_query = adata.obs['study'] == to_query
# Copy both slices so they are independent objects rather than views of adata.
query, ref = adata[mask_query].copy(), adata[~mask_query].copy()

In [5]:
# ref and query were created with .copy(), so the full object is not needed
# any more; dropping the name presumably serves to release its memory.
del adata

In [8]:
# Read the marker-gene collection and build {cell_type: [upper-cased gene names]}.
df = pd.read_excel(
    'marker_genes_collection.xlsx',
    sheet_name='Collection table',
)
markers = {
    cell_type: [symbol.upper() for symbol in gene_names]
    for cell_type, gene_names in df.groupby("cell_type")["gene_name"]
}

In [9]:
# Index the reference genes by the 'gene_symbol_original_matched' column so they
# can be compared against symbol-based gene sets and marker lists.
matched_symbols = ref.var['gene_symbol_original_matched'].astype(str)
ref.var_names = matched_symbols

In [10]:
# Annotate ref with the Reactome gene sets from the .gmt file, bounded by the
# min_genes / max_genes thresholds. Later cells read ref.varm['I'] and
# ref.uns['terms'], which are presumably populated by this call.
sca.add_annotations(
    ref,
    'c2.cp.reactome.v4.0.symbols.gmt',
    min_genes=20,
    max_genes=200,
)

In [11]:
# add marker genes to annotations
# Build a binary genes x cell-types matrix: entry (g, c) is 1 when the
# upper-cased gene name g is listed as a marker of cell type c.
var_names = ref.var_names.str.upper()

# Convert each marker list to a set once so that every membership test is O(1)
# instead of a linear scan of the list for every (gene, cell type) pair.
marker_sets = [set(genes) for genes in markers.values()]

I = [[int(gene in genes) for genes in marker_sets] for gene in var_names]
I = np.asarray(I, dtype='int32')

In [12]:
# Per marker cell type: how many of its marker genes were found in ref.
I.sum(axis=0)

array([ 1,  1,  0,  1,  1,  0,  0,  0,  1,  0,  0,  2,  3,  0,  1,  0,  0,
        0,  4,  5,  1,  0,  4,  0,  0,  5,  1,  1,  2, 13,  5,  1,  4, 12,
        9,  0, 11,  1,  3,  0, 30,  3])

In [13]:
# Keep only cell types with at least one marker gene present in ref, register
# their names as additional terms, and drop the empty columns from I.
nz = I.sum(axis=0) > 0
ref.uns['terms'] += [cell_type for cell_type, has_genes in zip(markers.keys(), nz) if has_genes]
I = I[:, nz]

In [14]:
# Matched marker genes per retained cell type (columns of I).
I.sum(axis=0)

array([ 1,  1,  1,  1,  1,  2,  3,  1,  4,  5,  1,  4,  5,  1,  1,  2, 13,
        5,  1,  4, 12,  9, 11,  1,  3, 30,  3])

In [15]:
# Append the marker-gene indicator columns to the existing annotation matrix.
# The result is built first and checked against the term list before it is
# assigned, so re-running this cell cannot silently append the marker columns a
# second time and leave ref.varm['I'] out of sync with ref.uns['terms'].
annot_with_markers = np.concatenate((ref.varm['I'], I), axis=1)
assert annot_with_markers.shape[1] == len(ref.uns['terms']), (
    "annotation columns and ref.uns['terms'] are out of sync; "
    "re-run the notebook starting from sca.add_annotations"
)
ref.varm['I'] = annot_with_markers

In [16]:
# rm genes not in annotations
# Keep only genes that belong to at least one annotation term.
keep_genes = ref.varm['I'].sum(1) > 0

# ref and query were sliced by cells from the same object, so they share the
# same genes in the same order. The query must be reduced to the same genes as
# the reference; otherwise the query model built from it later would have a
# different input dimension than the trained reference network.
assert query.n_vars == ref.n_vars, "ref and query no longer share the same genes"
query._inplace_subset_var(keep_genes)
ref._inplace_subset_var(keep_genes)

In [17]:
# (cells, genes) of the reference after gene filtering.
(ref.n_obs, ref.n_vars)

(144125, 3631)

In [38]:
# Reference model configuration. The mask is the annotation matrix transposed,
# i.e. terms x genes, so each term is tied to the genes annotated to it.
trvae_kwargs = dict(
    condition_key='study',
    hidden_layer_sizes=[1000, 600, 600],
    use_mmd=False,
    recon_loss='nb',
    mask=ref.varm['I'].T,
    use_decoder_relu=False,
    mmd_instead_kl=False,
)
intr_cvae = sca.models.TRVAE(adata=ref, **trvae_kwargs)


INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 3631 1000 4
	Hidden Layer 1 in/out: 1000 600
	Hidden Layer 2 in/out: 600 600
	Mean/Var Layer in/out: 600 341
Decoder Architecture:
	Masked linear layer in, out and cond:  341 3631 4


In [39]:
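# Disabled alternative: a per-term OMEGA tensor proportional to the number of
# genes in each term (normalised by the maximum, then scaled by 10).
# The training cell below uses OMEGA = None instead.
#OMEGA = ref.varm['I'].sum(0).astype('float32')
#OMEGA /= OMEGA.max()
#OMEGA *= 10
#OMEGA = torch.tensor(OMEGA)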

In [None]:
# train reference

# Values forwarded as train()'s `alpha` and `omega` arguments.
ALPHA = 0.8
OMEGA = None

# Early-stopping configuration handed to the trainer; the monitored metric is
# "val_unweighted_loss".
early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",
    threshold=0,
    patience=20,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

# Collect the training arguments in one place, then launch training.
reference_train_kwargs = dict(
    n_epochs=500,
    alpha_epoch_anneal=200,
    alpha=ALPHA,
    omega=OMEGA,
    alpha_kl=0.003,
    weight_decay=0.,
    early_stopping_kwargs=early_stopping_kwargs,
    use_early_stopping=True,
    seed=2021,
)
intr_cvae.train(**reference_train_kwargs)

In [None]:
# Create the query model from the held-out study and the trained reference model.
q_intr_cvae = sca.models.TRVAE.load_query_data(
    query,
    intr_cvae,
)

In [None]:
# Train the query model; these values match the corresponding reference-training
# arguments (n_epochs, alpha_epoch_anneal, weight_decay, alpha_kl, seed).
query_train_kwargs = dict(
    n_epochs=500,
    alpha_epoch_anneal=200,
    weight_decay=0.,
    alpha_kl=0.003,
    seed=2021,
)
q_intr_cvae.train(**query_train_kwargs)