In [1]:
import anndata
import numpy as np
import pandas as pd
import tensorflow as tf

from pathlib import Path
from sklearn.model_selection import train_test_split

from pmvae.model import PMVAE
from pmvae.train import train
from pmvae.utils import load_annotations

In [2]:
# Load the Kang count matrix from disk. `anndata.read` is only a deprecated
# alias of `anndata.read_h5ad`; the explicit reader is used so the cell keeps
# working on anndata releases that no longer ship the alias.
data = anndata.read_h5ad('../data/kang_count.h5ad')

# Store the annotation table returned by `load_annotations` (built from the
# Reactome GMT file and this dataset's gene names) alongside the genes.
# NOTE(review): `min_genes=13` presumably filters out small gene sets --
# confirm the exact meaning against pmvae.utils.load_annotations.
data.varm['annotations'] = load_annotations(
    '../data/c2.cp.reactome.v4.0.symbols.gmt',
    data.var_names,
    min_genes=13
)


In [3]:
# Boolean membership table: the stored annotations cast to bool and transposed.
membership_mask = data.varm['annotations'].astype(bool).T

batch_size = 256
shuffle_buffer = 5 * batch_size

# Hold out a quarter of the rows; the split is seeded so it is repeatable.
train_array, testset = train_test_split(
    data.X, test_size=0.25, shuffle=True, random_state=0)

# Only the training rows are wrapped in a shuffled, batched tf.data pipeline;
# `testset` stays as returned by the split.
trainset = (
    tf.data.Dataset.from_tensor_slices(train_array)
    .shuffle(shuffle_buffer)
    .batch(batch_size)
)


In [4]:
# Hyperparameters for the PMVAE, gathered in one place so they are easy to scan.
pmvae_kwargs = dict(
    membership_mask=membership_mask.values,
    module_latent_dim=4,
    hidden_layers=[12],
    add_auxiliary_module=True,
    beta=1e-5,
    kernel_initializer='he_uniform',
    bias_initializer='zero',
    activation='elu',
    terms=membership_mask.index,
)
model = PMVAE(**pmvae_kwargs)

# Adam optimizer with a learning rate of 1e-3.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


In [5]:
# Fit the model with `nepochs=1200`, passing the batched training dataset and
# the held-out test split to `train`; the returned value is kept as `history`.
# Expect roughly an hour on a GPU (the first iteration alone takes about a minute).
history = train(model, opt, trainset, testset, nepochs=1200)

100%|██████████| 1200/1200 [47:47<00:00,  2.39s/it]


In [6]:
# Run the trained model over the full matrix (`data.X`), not just one split.
# NOTE(review): this invokes `call` directly instead of `model(...)`; confirm
# that bypassing `__call__` is intended.
outputs = model.call(data.X)

In [7]:
outpath = Path('../data/kang_recons.h5ad')

# Global reconstruction, labelled with the input's observation and variable names.
recon_frame = pd.DataFrame(
    outputs.global_recon.numpy(),
    index=data.obs_names,
    columns=data.var_names,
)
recons = anndata.AnnData(recon_frame, obs=data.obs, varm=data.varm)

# Latent codes, one column per name reported by the model.
latent_frame = pd.DataFrame(
    outputs.z.numpy(),
    index=data.obs_names,
    columns=model.latent_space_names(),
)
recons.obsm['codes'] = latent_frame

# Persist reconstruction and codes together.
recons.write(outpath)

In [8]:
from sklearn.manifold import TSNE

def extract_pathway_cols(df, pathway):
    """Return the columns of `df` whose label begins with `pathway` plus '-'.

    The trailing dash is part of the prefix, so 'A' selects 'A-0' but not
    'AB-0'. All rows are kept and column order is preserved.
    """
    prefix = pathway + '-'
    belongs_to_pathway = df.columns.str.startswith(prefix)
    return df.loc[:, belongs_to_pathway]

def compute_tsnes(recons, pathways, random_state=None):
    """Yield one 2-D t-SNE embedding per pathway, computed on its latent codes.

    Parameters
    ----------
    recons
        Object exposing `obsm['codes']` (a DataFrame of latent codes) and
        `obs_names` (used as the index of every yielded frame).
    pathways
        Iterable of pathway names; for each one the code columns are picked
        with `extract_pathway_cols`.
    random_state
        Forwarded to `TSNE`. The default `None` keeps the previous unseeded
        behaviour; pass an integer for repeatable embeddings.

    Yields
    ------
    pandas.DataFrame
        Two columns named '<pathway>-0' and '<pathway>-1'.
    """
    for key in pathways:
        codes = extract_pathway_cols(recons.obsm['codes'], key)
        # Build the estimator once and actually fit it. Previously a configured
        # TSNE instance was created, ignored, and a second default one fitted.
        embedder = TSNE(n_components=2, random_state=random_state)
        yield pd.DataFrame(
            embedder.fit_transform(codes.values),
            index=recons.obs_names,
            columns=[f'{key}-0', f'{key}-1'])

# Pathways whose latent codes get their own t-SNE embedding.
pathways = [
    'REACTOME_INTERFERON_ALPHA_BETA_SIGNALING',
    'REACTOME_CYTOKINE_SIGNALING_IN_IMMUNE_SYSTEM',
    'REACTOME_TCR_SIGNALING',
    'REACTOME_CELL_CYCLE',
]

# Compute every embedding, then place them side by side in one table.
tsne_frames = list(compute_tsnes(recons, pathways))
recons.obsm['pathway_tsnes'] = pd.concat(tsne_frames, axis=1)

# Overwrite the saved file so it also carries the embeddings.
recons.write(outpath)