In [1]:
import os

# Pin the notebook to GPU 0. Set this before importing torch so the restriction
# is already in place before anything can initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import anndata
import numpy as np
import pandas as pd

import torch

# Seed the global RNGs so model initialisation / training is repeatable.
# The train/test split is seeded separately via its own random_state argument.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
from utils import load_annotations

In [3]:
from sklearn.model_selection import train_test_split

# Load data and pathway annotations

In [4]:
# Load the Kang count matrix. read_h5ad is the supported reader; the bare
# anndata.read alias is deprecated.
data = anndata.read_h5ad('data/kang_count.h5ad')

In [5]:
# Reactome gene sets in GMT format, restricted to the genes present in `data`.
# NOTE(review): min_genes=13 presumably drops pathways with fewer than 13 of
# these genes — confirm against utils.load_annotations.
reactome_gmt = 'data/c2.cp.reactome.v7.4.symbols.gmt'
pathway_ann_matrix = load_annotations(reactome_gmt, data.var_names, min_genes=13)

In [6]:
# Pathways whose name mentions the G2/M transition or PLK1 (the same filter is
# used further down to drop them from the annotation matrix).
list(filter(lambda name: 'G2_M_TRANSITION' in name or 'PLK1' in name,
            pathway_ann_matrix.columns))

['REACTOME_REGULATION_OF_PLK1_ACTIVITY_AT_G2_M_TRANSITION']

In [7]:
# Genes annotated to the PLK1 / G2-M pathway (rows where its column is True).
plk1_pathway = 'REACTOME_REGULATION_OF_PLK1_ACTIVITY_AT_G2_M_TRANSITION'
pathway_ann_matrix.loc[pathway_ann_matrix[plk1_pathway], [plk1_pathway]]

Unnamed: 0_level_0,REACTOME_REGULATION_OF_PLK1_ACTIVITY_AT_G2_M_TRANSITION
index,Unnamed: 1_level_1
PPP1CB,True
CLASP1,True
TUBA4A,True
CCNB1,True
TUBB,True
CUL1,True
CDK5RAP2,True
TUBB4B,True
CDK1,True
ACTR1A,True


In [8]:
# Hold out every G2/M-transition / PLK1 pathway: the model is built without them.
true_pathways_list = list(filter(
    lambda name: 'G2_M_TRANSITION' in name or 'PLK1' in name,
    pathway_ann_matrix.columns,
))
drop_pathway_ann_matrix = pathway_ann_matrix.drop(columns=true_pathways_list)

In [9]:
# Store the annotation matrix (G2/M / PLK1 pathways removed) on the AnnData
# object; the membership mask for the model is read back from this entry.
data.varm['annotations'] = drop_pathway_ann_matrix

In [10]:
# Which of the remaining pathways contain IFITM3? `== True` is an element-wise
# numpy comparison here, producing a positional column mask.
ifitm3_membership = drop_pathway_ann_matrix.loc['IFITM3', :].to_numpy() == True
drop_pathway_ann_matrix.iloc[:, ifitm3_membership]

Unnamed: 0_level_0,REACTOME_CYTOKINE_SIGNALING_IN_IMMUNE_SYSTEM,REACTOME_INTERFERON_ALPHA_BETA_SIGNALING,REACTOME_INTERFERON_SIGNALING
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ISG15,True,True,True
MIB2,False,False,False
PRKCZ,False,False,False
KCNAB2,False,False,False
CTNNBIP1,False,False,False
...,...,...,...
CYP19A1,False,False,False
RAP1GAP2,False,False,False
SSTR2,False,False,False
BIRC5,True,False,False


In [11]:
# Pathway-by-gene boolean mask: varm rows are genes, so transposing puts
# pathways on the rows and genes on the columns.
membership_mask = data.varm['annotations'].astype(bool).T

# 75% / 25% split of the rows of data.X; random_state pins the shuffle so the
# split is reproducible.
X_train, X_test = train_test_split(
    data.X, random_state=0, shuffle=True, test_size=0.25,
)

# Initialize model

In [12]:
from models import pmVAEModel

In [13]:
# One module per pathway row of membership_mask, plus auxiliary latent units.
# Judging by the printed architecture, [12] appears to be the per-module hidden
# width (2400 = 200 x 12) and 4 the per-module latent size (800 = 200 x 4) —
# confirm against pmVAEModel's signature in models.py.
kangVAE = pmVAEModel(
    membership_mask.values,
    [12],
    4,
    terms=membership_mask.index,
    add_auxiliary_module=True,
    beta=1e-05,  # presumably the KL-term weight — confirm in models.py
)

In [14]:
# Show the underlying torch module to sanity-check layer widths; the 979-wide
# input/output layers should match the number of genes in membership_mask.
kangVAE.model

pmVAE(
  (encoder_net): pmEncoder(
    (encoder_dense_1): CustomizedLinear(input_features=979, output_features=2400, bias=True)
    (encoder_norm_1): BatchNorm1d(2400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (encoder_elu_1): ELU(alpha=1.0, inplace=True)
    (encoder_dense_2): CustomizedLinear(input_features=2400, output_features=1600, bias=True)
    (encoder_norm_2): BatchNorm1d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder_net): pmDecoder(
    (decoder_dense_1): CustomizedLinear(input_features=800, output_features=2400, bias=True)
    (decoder_norm_1): BatchNorm1d(2400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (decoder_elu_1): ELU(alpha=1.0, inplace=True)
  )
  (merge_layer): CustomizedLinear(input_features=2400, output_features=979, bias=False)
)

# Train model (not run in this session; the next section loads a saved checkpoint)

In [None]:
# FIXME(review): `train_ds` / `test_ds` are not defined in any cell of this
# notebook, so the original call raised NameError on Restart & Run All. The
# train/test data actually in scope are X_train / X_test from the split cell;
# this assumes pmVAEModel.train accepts them as-is — TODO confirm (it may expect
# DataLoaders instead).
kangVAE.train(X_train, X_test, checkpoint_path='pmvae_dropG2M_checkpoint.pkl')

# Explain model: gene attributions for the AUXILIARY-0 latent unit

In [15]:
# Load a previously saved checkpoint instead of re-training.
# NOTE(review): the training cell passes checkpoint_path='pmvae_dropG2M_checkpoint.pkl',
# but this loads 'saved_models/...pkl.best_loss'. Presumably train() adds the
# directory and the '.best_loss' suffix itself — confirm, otherwise this path is stale.
kangVAE.load_checkpoint('saved_models/pmvae_dropG2M_checkpoint.pkl.best_loss')

In [16]:
# Presumably moves the model to CPU. The attribution cells build their tensors
# with torch.tensor/torch.zeros (default device), so model and inputs must agree.
kangVAE.set_gpu(False)

In [17]:
# Total number of named latent units across all modules.
len(kangVAE.latent_space_names())

800

In [18]:
# Position of the first auxiliary latent unit. With 800 units in total this is
# index -4, the column that model_latent_wrapper explains.
kangVAE.latent_space_names().index('AUXILIARY-0')

796

In [19]:
# Check which latent unit sits at index -4 (expected: AUXILIARY-0).
kangVAE.latent_space_names()[-4]

'AUXILIARY-0'

In [20]:
# Check which latent unit sits at index -3 (expected: AUXILIARY-1).
kangVAE.latent_space_names()[-3]

'AUXILIARY-1'

In [21]:
# Check which latent unit sits at index -2 (expected: AUXILIARY-2).
kangVAE.latent_space_names()[-2]

'AUXILIARY-2'

In [22]:
# Check which latent unit sits at index -1 (expected: AUXILIARY-3).
kangVAE.latent_space_names()[-1]

'AUXILIARY-3'

In [23]:
def model_latent_wrapper(x, latent_idx=-4):
    """Expose one latent coordinate of ``kangVAE`` as a single-output model.

    PathExplainerTorch is given this function, so the attributions explain the
    selected column of the model's ``mu`` output.

    Parameters
    ----------
    x : torch.Tensor
        Batch of inputs, passed straight to ``kangVAE.model``.
    latent_idx : int, default -4
        Column of ``mu`` to explain. For this model ``-4`` is 'AUXILIARY-0'
        (800 latent units, with AUXILIARY-0 at index 796).

    Returns
    -------
    torch.Tensor
        The selected column reshaped to ``(-1, 1)``. Assumes ``mu`` is
        ``(batch, n_latent)``, in which case the result is ``(batch, 1)``.
    """
    # NOTE(review): if kangVAE.model is still in train mode, its BatchNorm layers
    # use per-batch statistics and attributions depend on batch composition —
    # confirm that load_checkpoint/set_gpu put the model in eval mode.
    outs = kangVAE.model(x)
    z = outs.mu
    # Use reshape rather than slicing so any index (including -1) yields a column.
    return z[:, latent_idx].reshape(-1, 1)

In [24]:
from pathexplainer import PathExplainerTorch

In [25]:
# Explain every row of data.X; gradients must be able to flow to the inputs.
input_data = torch.tensor(data.X, requires_grad=True)
# All-zeros reference point with one value per column of data.X.
baseline_data = torch.zeros(data.X.shape[1], requires_grad=True)

In [26]:
# Attributions of the wrapped latent unit with respect to every input column,
# integrating along the path from baseline_data to each input (200 samples).
explainer = PathExplainerTorch(model_latent_wrapper)
attributions = explainer.attributions(
    input_data,
    baseline=baseline_data,
    # Per the path_explain docs, False uses the single fixed baseline above
    # (integrated-gradients style) rather than sampling baselines.
    use_expectation=False,
    num_samples=200,
)

In [27]:
# Detach from the autograd graph and convert to NumPy. .cpu() makes this work
# even if the attributions were computed on a GPU; it is a no-op for CPU tensors.
np_attribs = attributions.detach().cpu().numpy()

In [28]:
# Per-gene summary of attribution magnitude across all explained rows. Rows are
# labelled with membership_mask's gene columns, i.e. the model's input order.
abs_attribs = np.abs(np_attribs)
top = pd.DataFrame(
    {'means': abs_attribs.mean(axis=0), 'stds': abs_attribs.std(axis=0)},
    index=membership_mask.columns,
)


In [29]:
# Genes ranked by mean |attribution| for the explained latent unit
# (pandas truncates the middle of the display).
top.sort_values('means',ascending=False)

Unnamed: 0_level_0,means,stds
index,Unnamed: 1_level_1,Unnamed: 2_level_1
H2AFZ,1.558621,0.690636
IL8,0.588597,0.379918
PLA2G7,0.433617,0.340465
SSB,0.398044,0.208317
HIST1H2AC,0.234484,0.173549
...,...,...
IFNB1,0.000011,0.000189
PELI3,0.000010,0.000337
AURKB,0.000010,0.000136
SRGAP3,0.000010,0.000202


In [30]:
# Make sure the output directory exists; to_csv does not create parent directories.
os.makedirs('kang_remove_g', exist_ok=True)
top.to_csv('kang_remove_g/aux_0.csv')