In [1]:
import anndata
import numpy as np
import pandas as pd
import tensorflow as tf

from pathlib import Path
from sklearn.model_selection import train_test_split

from pmvae.model import PMVAE
from pmvae.train import train

In [2]:
def remove_unneeded_keys_kang(data):
    """Strip the Kang AnnData object down to the fields this notebook uses.

    Keeps only the ``condition`` and ``cell_type`` columns of ``data.obs``,
    replaces ``data.uns`` with an empty dict and sets ``data.obsm`` and
    ``data.varm`` to None. ``data`` is modified in place; returns None.
    """
    kept_columns = ['condition', 'cell_type']
    data.obs = data.obs[kept_columns]
    data.uns = {}
    # Tuple assignment runs left to right: obsm first, then varm.
    data.obsm, data.varm = None, None

def add_annotations_to_varm(gmt, data, min_genes=13):
    """Attach a gene-by-pathway membership matrix to ``data``.

    Gene sets are read from the GMT file ``gmt`` via ``load_gmt_genesets``,
    restricted to ``data.var_names``. Sets with fewer than ``min_genes``
    matching genes are dropped.

    Side effects (``data`` is modified in place; returns None):
      * ``data.varm['annotations']``: boolean DataFrame indexed by
        ``data.var_names`` with one column per retained gene set. The cell is
        True when the gene belongs to the set.
      * ``data.uns['terms']``: list of retained gene-set names, in the same
        order as the annotation columns.
    """
    genesets = load_gmt_genesets(gmt, data.var_names, min_genes)
    # Build each membership column in one vectorised pass instead of writing
    # cells one gene set at a time via .loc.
    annotations = pd.DataFrame(
        {key: data.var_names.isin(genes) for key, genes in genesets.items()},
        index=data.var_names,
    )
    data.varm['annotations'] = annotations
    # Record pathway names explicitly. data.uns may have been reset before this
    # call, and later cells label latent dimensions with these names.
    data.uns['terms'] = list(genesets.keys())

def load_gmt_genesets(path, symbols=None, min_genes=10, sep='\t'):
    """Read gene sets from a GMT file.

    Each GMT line is ``<name><sep><description><sep><gene1><sep><gene2>...``.
    The description field is ignored. Blank or malformed lines (fewer than two
    fields) are skipped.

    Parameters
    ----------
    path : str or path-like
        GMT file to read.
    symbols : pandas.Index or iterable of str, optional
        If given, each gene set is restricted to the genes present in
        ``symbols``. The result is deduplicated and ordered as in ``symbols``.
    min_genes : int
        Gene sets with fewer than ``min_genes`` genes after filtering are dropped.
    sep : str or None
        Field separator. GMT is tab-delimited by specification; pass ``None``
        to split on any run of whitespace (the previous behaviour).

    Returns
    -------
    dict
        Mapping of gene-set name to list of gene symbols, in file order.
    """
    if symbols is not None and not isinstance(symbols, pd.Index):
        # Accept plain lists/sets too; pd.Index provides .intersection().tolist().
        symbols = pd.Index(symbols)

    lut = dict()
    with open(path, 'r') as handle:
        for line in handle:
            fields = line.strip().split(sep)
            if len(fields) < 2:
                # Blank line or a line without a description column.
                continue
            key, _, *genes = fields
            # Drop empty fields produced by doubled or trailing separators.
            genes = [gene.strip() for gene in genes if gene.strip()]
            if symbols is not None:
                genes = symbols.intersection(genes).tolist()
            if len(genes) < min_genes:
                continue
            lut[key] = genes
    return lut


In [9]:
# Load the Kang count matrix, drop metadata the model does not use, and attach
# the Reactome pathway membership matrix as data.varm['annotations'].
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader.
DATA_DIR = Path('../data')
data = anndata.read_h5ad(DATA_DIR / 'kang_count.h5ad')
remove_unneeded_keys_kang(data)
add_annotations_to_varm(DATA_DIR / 'c2.cp.reactome.v4.0.symbols.gmt', data)



In [14]:
# Seed every stochastic step (train/test split, tf.data shuffle, and the
# TensorFlow global seed used for weight initialisation) so re-runs match.
SEED = 0
tf.random.set_seed(SEED)

# annotations is (genes x pathways), so the transpose is (pathways x genes).
membership_mask = data.varm['annotations'].values.astype(bool).T

train_X, testset = train_test_split(
    data.X, shuffle=True, test_size=0.25, random_state=SEED,
)

batch_size = 256
trainset = (
    tf.data.Dataset.from_tensor_slices(train_X)
    .shuffle(5 * batch_size, seed=SEED)
    .batch(batch_size)
)

In [15]:
# Keyword options for the pathway-module VAE. add_auxiliary_module=True is
# why embeddings_to_df appends an 'AUXILIARY' term by default.
pmvae_options = dict(
    add_auxiliary_module=True,
    beta=1e-05,
    kernel_initializer='he_uniform',
    bias_initializer='zero',
    activation='elu',
)

# Positional args: membership mask, then 4 and [12]. Presumably the latent
# size per pathway module and the hidden-layer widths; TODO confirm against
# pmvae.model.PMVAE.
model = PMVAE(membership_mask, 4, [12], **pmvae_options)

opt = tf.keras.optimizers.Adam(learning_rate=0.001)

In [16]:
# Fit the model on the batched training set, using testset for evaluation.
# NOTE(review): the saved run of this cell was interrupted (KeyboardInterrupt),
# so later cells did not see a fully trained model.
N_EPOCHS = 1200
history = train(model, opt, trainset, testset, nepochs=N_EPOCHS)

KeyboardInterrupt: 

In [11]:
# Run every cell through the model. The result exposes .global_recon, .mu and
# .logvar, which are saved below.
# NOTE(review): the saved execution count (In [11]) is lower than the training
# cell's (In [16]), so the saved outputs came from an earlier model state.
# Re-run this cell after training.
all_cells_X = data.X
outputs = model(all_cells_X)

In [48]:
def embeddings_to_df(codes, terms, index, add_auxiliary=True):
    """Label a latent-code matrix with per-pathway column names.

    The columns of ``codes`` are assumed to be grouped by term: the first
    ``k`` columns belong to the first term, the next ``k`` to the second,
    and so on. Here ``k = codes.shape[1] // n_terms``. Columns are named
    ``'<term>-<i>'`` for ``i`` in ``0..k-1``.

    Parameters
    ----------
    codes : 2-D array-like, shape (n_rows, n_latent)
        Latent values, one row per entry of ``index``.
    terms : iterable of str
        Term (pathway) names, in the same order as the column groups.
    index : sequence
        Row labels for the returned frame.
    add_auxiliary : bool
        If True, an extra trailing term named ``'AUXILIARY'`` is appended.

    Returns
    -------
    pandas.DataFrame
        ``codes`` with the generated column names and ``index`` as row labels.

    Raises
    ------
    ValueError
        If there are no terms, or ``n_latent`` is not a multiple of the
        number of terms.
    """
    names = list(terms)
    if add_auxiliary:
        names.append('AUXILIARY')

    n_terms = len(names)
    n_latent = codes.shape[1]
    if n_terms == 0 or n_latent % n_terms != 0:
        raise ValueError(
            f'cannot split {n_latent} latent columns evenly across {n_terms} terms'
        )
    latent_dim_per_pathway = n_latent // n_terms

    columns = [
        f'{name}-{i}'
        for name in names
        for i in range(latent_dim_per_pathway)
    ]
    return pd.DataFrame(codes, columns=columns, index=index)

# Save the reconstruction together with per-pathway latent means and
# log-variances, so they can be analysed without re-running the model.
outdir = Path('./results')
outdir.mkdir(exist_ok=True, parents=True)

# Pathway names in the same order as the model's modules. membership_mask is
# data.varm['annotations'].values.T, so the annotation columns give that order.
# (remove_unneeded_keys_kang resets data.uns, so it is not a reliable source.)
terms = data.varm['annotations'].columns

recons = anndata.AnnData(
    outputs.global_recon.numpy(),
    obs=data.obs,
    # Keep gene names: varm['annotations'] is indexed by gene and must align
    # with var_names.
    var=data.var,
    uns=data.uns,
    varm=data.varm,
)

recons.obsm['codes'] = embeddings_to_df(outputs.mu.numpy(), terms, data.obs_names)
recons.obsm['logvar'] = embeddings_to_df(outputs.logvar.numpy(), terms, data.obs_names)

recons.write(outdir / 'recons.h5ad')