In [3]:
import numpy as np
import torch
import scanpy as sc
from data import PertDataloader, Network

import warnings
warnings.filterwarnings("ignore")

# Maps a short, human-readable model label to the basename used for that
# model's files under ./saved_args/ (this cell) and ./saved_models/ (later cell).
name2path = {
    "GNN_Disentangle-L2":
        "GNN_Disentangle_GAT_string_20.0_64_2_l2_Norman2019_gene_emb_pert_emb_constant_sim_gnn",
    "GNN_Disentangle_Sim":
        "GNN_Disentangle_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn",
    "GNN_Disentangle_Sim_No_Gene":
        "GNN_Disentangle_sim_gnn",
    "No-Perturb":
        "No_Perturb_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn",
    "No-GNN":
        "best_no_gnn",
}

# Label of the model evaluated by the rest of the notebook.
name = "No-GNN"
model_name = name2path[name]

# The saved arguments are a pickled object wrapped in a 0-d array, so
# allow_pickle is required and .item() unwraps it.
args = np.load("./saved_args/{}.npy".format(model_name), allow_pickle=True).item()

# Override whatever device was stored with the arguments.
args["device"] = "cuda:5"

In [4]:
# Resolve the location of the graph file for the configured network.
# Other network names keep whatever 'network_path' the saved args already hold.
if args['network_name'] == 'string':
    args['network_path'] = '/dfs/project/perturb-gnn/graphs/STRING_full_9606.csv'

# Resolve the expression dataset. Fail loudly for an unknown dataset instead of
# hitting a NameError on `data_path` further down.
if args['dataset'] == 'Norman2019':
    data_path = '/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts_more_de.h5ad'
else:
    raise ValueError(
        "No data path configured for dataset " + repr(args['dataset'])
        + "; only 'Norman2019' is supported in this notebook."
    )

adata = sc.read_h5ad(data_path)
# Downstream code reads `gene_symbols`; fall back to `gene_name` when it is absent.
if 'gene_symbols' not in adata.var.columns.values:
    adata.var['gene_symbols'] = adata.var['gene_name']
gene_list = [f for f in adata.var.gene_symbols.values]
# NOTE(review): `gene_list` above is derived from adata, but the network below is
# built from args['gene_list'] loaded with the saved args -- confirm the two agree.

# Set up message passing network
network = Network(fname=args['network_path'], gene_list=args['gene_list'],
                  percentile=args['top_edge_percent'])

# Perturbation dataloader
pertdl = PertDataloader(adata, network.G, network.weights, args)

There are 24886 edges in the PPI.
Creating pyg object for each cell in the data...
Local copy of pyg dataset is detected. Loading...
Loading splits...
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Creating dataloaders....
Dataloaders created...


In [5]:
from inference import evaluate, compute_metrics, deeper_analysis, GI_subgroup

# Restore the pickled model onto the device selected in the config cell.
# Without map_location the tensors are restored onto the device they were saved
# from, which may differ from args['device'] or not exist on this machine.
model = torch.load('./saved_models/' + model_name, map_location=args['device'])

# Run the model over the held-out test split using the shared graph
# (edge_index / edge_attr) stored alongside the loaders.
test_res = evaluate(pertdl.loaders['test_loader'],
                    pertdl.loaders['edge_index'],
                    pertdl.loaders['edge_attr'], model, args)

# compute_metrics returns a pair; the second element is indexed by perturbation
# in the subgroup cell below.
test_metrics, test_pert_res = compute_metrics(test_res)

In [8]:
import pickle

# NOTE(review): `metrics` is not used in this cell; kept in case a later cell reads it.
metrics = ['mse', 'mae', 'spearman', 'pearson', 'r2']
subgroup_path = './splits/Norman2019_simulation_1_0.1_subgroup.pkl'
# Use a context manager so the split file is closed as soon as it is read.
# pickle.load executes arbitrary code: only point this at trusted, locally produced files.
with open(subgroup_path, "rb") as subgroup_file:
    subgroup = pickle.load(subgroup_file)

# Every perturbation is assumed to report the same metric keys as the first one.
metric_names = list(next(iter(test_pert_res.values())).keys())

# One empty accumulator list per (subgroup, metric). The loop variable is named
# `subgroup_name` so it does not overwrite the notebook-level `name` (model label).
subgroup_analysis = {
    subgroup_name: {m: [] for m in metric_names}
    for subgroup_name in subgroup['test_subgroup'].keys()
}

# Collect each perturbation's per-metric result under the subgroup it belongs to.
for subgroup_name, pert_list in subgroup['test_subgroup'].items():
    for pert in pert_list:
        for m, res in test_pert_res[pert].items():
            subgroup_analysis[subgroup_name][m].append(res)

# Collapse each list to its mean and report it as test_<subgroup>_<metric>.
for subgroup_name, result in subgroup_analysis.items():
    for m in result.keys():
        subgroup_analysis[subgroup_name][m] = np.mean(subgroup_analysis[subgroup_name][m])
        print('test_' + subgroup_name + '_' + m + ': ' + str(subgroup_analysis[subgroup_name][m]))

test_combo_seen0_mse: 0.0149425315
test_combo_seen0_mae: 0.08616412
test_combo_seen0_spearman: 0.819335605624276
test_combo_seen0_pearson: 0.9708977858783006
test_combo_seen0_r2: 0.901788848892149
test_combo_seen0_mse_de: 0.18566363
test_combo_seen0_mae_de: 0.352761
test_combo_seen0_spearman_de: 0.7186594257562022
test_combo_seen0_pearson_de: 0.8483505186854559
test_combo_seen0_r2_de: -0.8162012349942849
test_combo_seen1_mse: 0.016289996
test_combo_seen1_mae: 0.09451111
test_combo_seen1_spearman: 0.8268785803302016
test_combo_seen1_pearson: 0.9729620212236851
test_combo_seen1_r2: 0.8926126228112553
test_combo_seen1_mse_de: 0.20140103
test_combo_seen1_mae_de: 0.3661127
test_combo_seen1_spearman_de: 0.785525628950841
test_combo_seen1_pearson_de: 0.8540952321097159
test_combo_seen1_r2_de: -2.2421149073586397
test_combo_seen2_mse: 0.01899721
test_combo_seen2_mae: 0.108064085
test_combo_seen2_spearman: 0.847178200541306
test_combo_seen2_pearson: 0.9747761116353681
test_combo_seen2_r2: 0.870

In [9]:
# Run the project's deeper_analysis helper on the AnnData object and the raw
# test-set evaluation output.
out = deeper_analysis(adata, test_res)
# Feed that result to GI_subgroup.
# NOTE(review): presumably groups results by genetic-interaction type -- confirm in inference.py.
GI_out = GI_subgroup(out)