In [1]:
from inference import InferencePertData, GEARSInference
import os
import torch
import random
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# download model and data
dataset_name = 'norman'
save_to = 'models'
os.system(f'aws s3 cp --recursive s3://syntensor-labs-data/inference/GEARS/{dataset_name} models/{dataset_name}/')

download: s3://syntensor-labs-data/inference/GEARS/norman/config.pkl to models/norman/config.pkl
download: s3://syntensor-labs-data/inference/GEARS/norman/test_preds.pkl to models/norman/test_preds.pkl
download: s3://syntensor-labs-data/inference/GEARS/norman/model.pt to models/norman/model.pt
download: s3://syntensor-labs-data/inference/GEARS/norman/infer_pertdata.pkl to models/norman/infer_pertdata.pkl


0

In [3]:
model = GEARSInference(f'models/{dataset_name}/infer_pertdata.pkl')
model.load_pretrained(f'models/{dataset_name}/')
possible_perturbations = model.pert_list

In [4]:
example_request_single_perts = [[x] for x in random.sample(possible_perturbations, 10)]
example_request_double_perts = [[x, y] for x,y in zip(random.sample(possible_perturbations, 10), random.sample(possible_perturbations, 10))]

In [5]:
example_request_single_perts

[['FIS1'],
 ['SOHLH1'],
 ['NUMBL'],
 ['PPM1G'],
 ['R3HDM4'],
 ['PAXIP1'],
 ['ZNF266'],
 ['LITAF'],
 ['COBL'],
 ['ZNF503']]

In [6]:
# to get raw perturbed expression results
results = model.predict(example_request_single_perts)
print(results)

{'FIS1': array([1.0906649e-03, 1.1652883e-02, 3.1155536e-02, ..., 3.6954989e+00,
       4.3421159e-03, 5.0031528e-04], dtype=float32), 'SOHLH1': array([1.0927292e-03, 1.5304216e-02, 4.3143187e-02, ..., 3.6415303e+00,
       5.3680832e-03, 6.2373752e-04], dtype=float32), 'NUMBL': array([9.9577324e-04, 1.5301293e-02, 3.6256541e-02, ..., 3.6689849e+00,
       4.0848227e-03, 4.6217450e-04], dtype=float32), 'PPM1G': array([9.8308909e-04, 1.4764603e-02, 3.6246613e-02, ..., 3.6642740e+00,
       3.8114686e-03, 4.2636736e-04], dtype=float32), 'R3HDM4': array([1.0831652e-03, 1.5455273e-02, 3.6536362e-02, ..., 3.6899271e+00,
       4.2636427e-03, 4.9359468e-04], dtype=float32), 'PAXIP1': array([1.0416181e-03, 1.3541746e-02, 3.7485082e-02, ..., 3.6581514e+00,
       4.2920937e-03, 4.8734731e-04], dtype=float32), 'ZNF266': array([1.0831652e-03, 1.7165283e-02, 3.8888853e-02, ..., 3.7212150e+00,
       4.2636427e-03, 4.9359468e-04], dtype=float32), 'LITAF': array([2.0675422e-03, 2.1229088e-02, 4.351

In [7]:
# to convert to a dict with ensembl_id gene keys
results_dict = {k:dict(zip(model.gene_list, v)) for k,v in results.items()}

In [8]:
# to convert compare to control
diff = {k:v-model.ctrl_mean for k,v in results.items()}
logfc = {k:np.log2(v/model.ctrl_mean) for k,v in results.items()}