In [2]:
import pmvae

In [1]:
import anndata
import numpy as np
import pandas as pd
import tensorflow as tf

from pathlib import Path
from sklearn.model_selection import train_test_split

from pmvae.model import PMVAE
from pmvae.train import train

  from pandas.core.index import RangeIndex


ModuleNotFoundError: No module named 'utils'

In [2]:
# Load the annotated count matrix. `read_h5ad` is the explicit reader for
# .h5ad files; `anndata.read` is only a deprecated alias for it.
data = anndata.read_h5ad('./data/kang_count.h5ad')

# Boolean, transposed copy of varm['I'], handed to PMVAE as its membership mask.
# NOTE(review): presumably varm['I'] is genes x pathways, making this
# pathways x genes -- confirm against the PMVAE constructor.
membership_mask = data.varm['I'].astype(bool).T

# Pin the split seed so the same cells are held out on every
# Restart & Run All; otherwise test metrics are not comparable between runs.
SPLIT_SEED = 0
trainset, testset = train_test_split(
    data.X, shuffle=True, test_size=0.25, random_state=SPLIT_SEED)

batch_size = 256
trainset = tf.data.Dataset.from_tensor_slices(trainset)
# The shuffle buffer holds five batches' worth of rows.
trainset = trainset.shuffle(5 * batch_size).batch(batch_size)

In [10]:
# Keyword configuration forwarded unchanged to the PMVAE constructor.
pmvae_options = dict(
    add_auxiliary_module=True,
    beta=1e-05,
    kernel_initializer='he_uniform',
    bias_initializer='zero',
    activation='elu',
)

# NOTE(review): 4 and [12] are positional arguments; presumably the latent
# size per pathway module and the hidden layer widths -- confirm against
# the PMVAE signature.
model = PMVAE(membership_mask, 4, [12], **pmvae_options)

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

In [4]:
# Fit `model` with `opt` on the batched training set, passing the held-out
# split alongside it; whatever `train` returns is kept as `history`.
n_epochs = 1200
history = train(model, opt, trainset, testset, nepochs=n_epochs)

199.64026670716703
0.21853356 0.21849917


Unnamed: 0,train-loss,train-recon,train-kl,train-local,train-global,test-loss,test-recon,test-kl,test-local,test-global
0,0.230926,0.230893,3.313144,0.115915,0.114978,0.218534,0.218499,3.438679,0.109987,0.108512


In [11]:
# Run the model on the complete matrix (train and test rows together).
# The result's `global_recon`, `mu` and `logvar` attributes are read below
# when assembling the reconstruction AnnData.
outputs = model(data.X)

In [48]:
def embeddings_to_df(codes, terms, index, add_auxiliary=True):
    """Wrap a latent-code matrix in a DataFrame with per-term column labels.

    The columns of ``codes`` are split into equally sized, consecutive groups,
    one group per term. Column ``j`` of the group belonging to ``term`` is
    labelled ``'<term>-<j>'``.

    Parameters
    ----------
    codes : 2-D array-like exposing ``.shape``
        One row per entry of ``index``; the number of columns must be an
        exact multiple of the number of terms.
    terms : iterable of str
        Term names, in the order their column groups appear in ``codes``.
    index : array-like
        Row labels for the resulting frame.
    add_auxiliary : bool, default True
        When true, a trailing ``'AUXILIARY'`` term is appended to ``terms``.

    Returns
    -------
    pandas.DataFrame
        ``codes`` with labelled columns and ``index`` as row labels.

    Raises
    ------
    ValueError
        If there are no terms, or the number of columns in ``codes`` is not
        divisible by the number of terms.
    """
    names = list(terms)
    if add_auxiliary:
        names.append('AUXILIARY')

    n_terms = len(names)
    n_columns = codes.shape[1]
    # Fail early with a readable message instead of an opaque pandas shape
    # error (or a ZeroDivisionError when no terms are supplied).
    if n_terms == 0 or n_columns % n_terms != 0:
        raise ValueError(
            'cannot split {} latent columns evenly across {} terms'.format(
                n_columns, n_terms))

    dims_per_term = n_columns // n_terms
    # Term-major order: every dimension of the first term, then the next term.
    columns = [
        '{}-{}'.format(name, dim)
        for name in names
        for dim in range(dims_per_term)
    ]

    return pd.DataFrame(codes, columns=columns, index=index)

outdir = Path('./results')
outdir.mkdir(exist_ok=True, parents=True)

# Package the model's global reconstruction together with the observation,
# unstructured and varm annotations of the input data.
recons = anndata.AnnData(
    outputs.global_recon.numpy(),
    obs=data.obs,
    uns=data.uns,
    varm=data.varm,
)

# `mu` is stored under 'codes' and `logvar` under 'logvar'; both are labelled
# with the same per-term column scheme.
for obsm_key, tensor in (('codes', outputs.mu), ('logvar', outputs.logvar)):
    recons.obsm[obsm_key] = embeddings_to_df(
        tensor.numpy(),
        data.uns['terms'],
        data.obs_names)

# Write the reconstruction object. The previous code wrote `data`, so
# recons.h5ad was a copy of the input and the reconstruction and embeddings
# assembled above were never saved.
recons.write(outdir/'recons.h5ad')