In [1]:
import numpy as np
import torch
import scanpy as sc
from data import PertDataloader, Network


model_name = 'GNN_Disentangle_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb'
args = np.load('./saved_args/'+model_name+'.npy', allow_pickle = True).item()

In [2]:
if args['network_name'] == 'string':
    args['network_path'] = '/dfs/project/perturb-gnn/graphs/STRING_full_9606.csv'

if args['dataset'] == 'Norman2019':
    data_path = '/dfs/project/perturb-gnn/datasets/Norman2019_hvg+perts.h5ad'

adata = sc.read_h5ad(data_path)
if 'gene_symbols' not in adata.var.columns.values:
    adata.var['gene_symbols'] = adata.var['gene_name']
gene_list = [f for f in adata.var.gene_symbols.values]
# Set up message passing network
network = Network(fname=args['network_path'], gene_list=args['gene_list'],
                  percentile=args['top_edge_percent'])

# Pertrubation dataloader
pertdl = PertDataloader(adata, network.G, network.weights, args)

There are 101013 edges in the PPI.
Creating pyg object for each cell in the data...
Local copy of pyg dataset is detected. Loading...
Loading splits...
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Creating dataloaders....
Dataloaders created...


In [75]:
node_map_inv = {j:i for i,j in pertdl.node_map.items()}

In [71]:
network.G

<networkx.classes.digraph.DiGraph at 0x7fb757cb0eb0>

In [6]:
self = torch.load('./saved_models/'+ model_name)

In [9]:
for batch in pertdl.loaders['train_loader']:
    break

In [10]:
batch.to(args['device'])
data = batch
x, edge_index, edge_attr, batch = data.x, data.edge_index, \
                                          data.edge_attr, data.batch

In [13]:
pert = x[:, 1].reshape(-1,1)
pert_emb = self.pert_w(pert)


gene_base = x[:, 0].reshape(-1,1)
base_emb = self.gene_basal_w(gene_base)

emb = self.emb(torch.LongTensor(list(range(self.num_genes))).repeat(self.args['batch_size'], ).to(self.args['device']))
base_emb = torch.cat((emb, base_emb), axis = 1)
base_emb = self.emb_trans(base_emb)

In [14]:
base_emb.shape

torch.Size([161440, 64])

In [16]:
import torch.nn as nn
pert_emb_trans = nn.Linear(64, 64)

In [18]:
pert.shape

torch.Size([161440, 1])

In [32]:
pert_index = torch.where(pert.reshape(*data.y.shape) == 1)

In [35]:
emb_one_set = self.emb(torch.LongTensor(list(range(self.num_genes))).to(self.args['device']))

In [36]:
pert_emb_trans.to(args['device'])
pert_global_emb = pert_emb_trans(emb_one_set)

In [37]:
pert_global_emb.shape

torch.Size([5045, 64])

In [38]:
batch_pert_index = pert_index[0]
pert_global_emb_batch = pert_global_emb[pert_index[1]]

In [46]:
base_emb = base_emb.reshape(self.args['batch_size'], self.num_genes, -1)

In [67]:
from model import MLP
hidden_size = 64
pert_lambda_pred = MLP([hidden_size, hidden_size, 1], last_layer_act='ReLU')

In [68]:
pert_lambda_pred.to(args['device'])

MLP(
  (relu): ReLU()
  (network): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [69]:
pert_emb_lambda = pert_lambda_pred(pert_global_emb[pert_index[1]])

In [84]:
node_map_inv[pert_index[1][4].item()]

'LHX1'

In [78]:
pert_index[1][i].item()

4115

In [70]:
for i, j in enumerate(batch_pert_index):
    base_emb[j] += pert_emb_lambda[i] * pert_global_emb[pert_index[1][i]]

In [52]:
base_emb.reshape(self.args['batch_size'] * self.num_genes, -1)

tensor([[-6.7838e-02,  8.4169e-02, -5.8724e-02,  ...,  9.5274e-02,
         -3.1291e-02,  6.7311e-02],
        [-6.7334e-02,  8.4227e-02, -5.9290e-02,  ...,  9.6037e-02,
         -3.1273e-02,  6.7037e-02],
        [-6.5092e-02,  8.4487e-02, -6.1930e-02,  ...,  9.9567e-02,
         -3.1193e-02,  6.5812e-02],
        ...,
        [ 6.6734e-04,  1.1483e-01, -6.6853e-01,  ...,  5.0136e-01,
          1.5226e-02, -5.3027e-02],
        [-6.0340e-02,  1.4597e-01, -4.7651e-02,  ...,  1.7981e-01,
         -4.2171e-02,  1.5695e-01],
        [-6.0688e-02,  1.4593e-01, -4.7246e-02,  ...,  1.7927e-01,
         -4.2184e-02,  1.5714e-01]], device='cuda:5', grad_fn=<ViewBackward>)

In [98]:
np.random.binomial(1, 0.75, 1)[0]

0

In [46]:
df_eval[df_eval.group == 'combo_seen2']

Unnamed: 0,mse_de_cpa,mse_de_gnn,group
AHR+KLF1,0.357695,0.368303,combo_seen2
CEBPE+RUNX1T1,0.100133,0.339434,combo_seen2
CNN1+MAPK1,0.078004,0.113659,combo_seen2
CNN1+UBASH3A,0.783318,0.708051,combo_seen2
ETS2+CNN1,0.133394,0.093787,combo_seen2
ETS2+IKZF3,0.310006,0.968617,combo_seen2
ETS2+MAPK1,0.355581,0.954693,combo_seen2
FOSB+IKZF3,0.356463,0.50704,combo_seen2
FOSB+UBASH3B,0.050452,0.0694,combo_seen2
FOXA1+HOXB9,0.180823,0.211096,combo_seen2


In [47]:
df_eval.to_csv('./cpa_gnn_comparison.csv', index = False)