In [1]:
import sys
sys.path.append('../')

from pertnet import PertData, PertNet

Load the data. We use the `norman` dataset as an example.

In [2]:
pert_data = PertData('./data')
pert_data.load(data_name = 'norman')
pert_data.prepare_split(split = 'simulation', seed = 1)
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

Found local copy...
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Done!
Creating dataloaders....
Done!
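
The subgroup names in the split summary are not explained in this notebook. One plausible reading, stated here as an assumption rather than taken from the library, is that `combo_seenK` counts two-gene test perturbations in which K of the two genes also appear among the training perturbations, and `unseen_single` counts single-gene test perturbations whose gene was never perturbed in training. The helper `split_subgroup` and the label `seen_single` below are hypothetical and only illustrate that reading. The snippet also reports each subgroup's share of the test set, using the counts printed above.

```python
def split_subgroup(perturbation, train_genes):
    """Name the test subgroup of a perturbation (hypothetical helper).

    perturbation: list of one or two perturbed gene names.
    train_genes: set of genes perturbed somewhere in the training split.
    """
    if len(perturbation) == 1:
        # 'seen_single' is an illustrative label, not one printed by the library.
        return 'unseen_single' if perturbation[0] not in train_genes else 'seen_single'
    # Count how many of the combined genes were perturbed during training.
    seen = sum(gene in train_genes for gene in perturbation)
    return f'combo_seen{seen}'

# Counts copied from the split summary printed above.
composition = {'combo_seen0': 9, 'combo_seen1': 52, 'combo_seen2': 18, 'unseen_single': 37}
total = sum(composition.values())
for name, count in composition.items():
    print(f'{name}: {count} ({count / total:.1%})')
```

Under this reading, `combo_seen0` would be the hardest subgroup, because neither gene of the pair was perturbed in training.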


Create a model object. If you use [wandb](https://wandb.ai), you can track model training and evaluation by setting `weight_bias_track = True` and specifying a `proj_name` and `exp_name` of your choice.

In [3]:
pertnet_model = PertNet(pert_data, device = 'cuda:7', 
                        weight_bias_track = False, 
                        proj_name = 'pertnet', 
                        exp_name = 'pertnet')
pertnet_model.model_initialize(hidden_size = 64)
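
The cell above hard-codes `cuda:7`, a device that exists only on a machine with at least eight GPUs. The sketch below shows one way to choose the device string at runtime. It assumes the model runs on PyTorch (the device string follows PyTorch's format), and `pick_device` is a hypothetical helper, not part of `pertnet`. Whether `PertNet` accepts `'cpu'` is not confirmed by this notebook.

```python
def pick_device(preferred_index=0):
    """Return 'cuda:<index>' when that GPU exists, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # Without PyTorch installed there is no GPU backend to query.
        return 'cpu'
    if torch.cuda.is_available() and preferred_index < torch.cuda.device_count():
        return f'cuda:{preferred_index}'
    return 'cpu'

device = pick_device(7)
print(device)
```

The resulting string could then be passed as the `device` argument when constructing the model.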

You can list the tunable parameters of `model_initialize` via:

In [4]:
pertnet_model.tunable_parameters()

{'hidden_size': 'hidden dimension, default 64',
 'num_go_gnn_layers': 'number of GNN layers for GO graph, default 1',
 'num_gene_gnn_layers': 'number of GNN layers for co-expression gene graph, default 1',
 'decoder_hidden_size': 'hidden dimension for gene-specific decoder, default 16',
 'num_similar_genes_go_graph': 'number of maximum similar K genes in the GO graph, default 20',
 'num_similar_genes_co_express_graph': 'number of maximum similar K genes in the co expression graph, default 20',
 'coexpress_threshold': 'pearson correlation threshold when constructing coexpression graph, default 0.4',
 'uncertainty': 'whether or not to turn on uncertainty mode, default False',
 'uncertainty_reg': 'regularization term to balance uncertainty loss and prediction loss, default 1',
 'direction_lambda': 'regularization term to balance direction loss and prediction loss, default 1'}
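
Since the dictionary above lists the names that `model_initialize` understands, it can also serve to catch misspelled overrides before a long training run. The following is a minimal sketch: `check_overrides` is a hypothetical helper, the parameter names are copied from the output above, and the example values (`128`, `2`) are arbitrary rather than recommended settings.

```python
# Parameter names copied from the tunable_parameters() output above.
TUNABLE = {
    'hidden_size', 'num_go_gnn_layers', 'num_gene_gnn_layers',
    'decoder_hidden_size', 'num_similar_genes_go_graph',
    'num_similar_genes_co_express_graph', 'coexpress_threshold',
    'uncertainty', 'uncertainty_reg', 'direction_lambda',
}

def check_overrides(overrides, known=TUNABLE):
    """Return a copy of the overrides, or raise on an unknown name."""
    unknown = sorted(set(overrides) - set(known))
    if unknown:
        raise ValueError(f'unknown parameter(s): {unknown}')
    return dict(overrides)

config = check_overrides({'hidden_size': 128, 'num_go_gnn_layers': 2})
print(config)
# With a model object in scope, the checked dict could then be unpacked:
# pertnet_model.model_initialize(**config)
```

In a live session, the keys of `pertnet_model.tunable_parameters()` could replace the hand-copied `TUNABLE` set.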

Train your model:

Note: For the sake of the demo, we train for only 1 epoch. To get the full model, set `epochs = 20`.

In [5]:
pertnet_model.train(epochs = 1, lr = 1e-3)

Start Training...
Epoch 1 Step 1 Train Loss: 0.5698
Epoch 1 Step 51 Train Loss: 0.4824
Epoch 1 Step 101 Train Loss: 0.4848
Epoch 1 Step 151 Train Loss: 0.4174
Epoch 1 Step 201 Train Loss: 0.5737
Epoch 1 Step 251 Train Loss: 0.4751
Epoch 1 Step 301 Train Loss: 0.4553
Epoch 1 Step 351 Train Loss: 0.4441
Epoch 1 Step 401 Train Loss: 0.5072
Epoch 1 Step 451 Train Loss: 0.4946
Epoch 1 Step 501 Train Loss: 0.3759
Epoch 1 Step 551 Train Loss: 0.5398
Epoch 1 Step 601 Train Loss: 0.4312
Epoch 1 Step 651 Train Loss: 0.3959
Epoch 1 Step 701 Train Loss: 0.4093
Epoch 1 Step 751 Train Loss: 0.4570
Epoch 1 Step 801 Train Loss: 0.5324
Epoch 1 Step 851 Train Loss: 0.4404
Epoch 1 Step 901 Train Loss: 0.3997
Epoch 1 Step 951 Train Loss: 0.3840
Epoch 1 Step 1001 Train Loss: 0.4515
Epoch 1 Step 1051 Train Loss: 0.4805
Epoch 1 Step 1101 Train Loss: 0.4331
Epoch 1 Step 1151 Train Loss: 0.4536
Epoch 1 Step 1201 Train Loss: 0.4719
Epoch 1 Step 1251 Train Loss: 0.4553
Epoch 1 Step 1301 Train Loss: 0.5108
Epoch 
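
The per-step losses above fluctuate from one report to the next, so averaging over several reports gives a clearer view of the trend. The sketch below parses captured output in the format printed above and averages the loss over a group of lines. The helper names are hypothetical, and the sample contains the first four and last four complete lines of the log.

```python
import re

# Matches lines such as "Epoch 1 Step 51 Train Loss: 0.4824".
LOG_LINE = re.compile(r'Epoch (\d+) Step (\d+) Train Loss: ([0-9.]+)')

def parse_train_log(text):
    """Extract (epoch, step, loss) tuples from captured training output."""
    return [(int(e), int(s), float(loss)) for e, s, loss in LOG_LINE.findall(text)]

def mean_loss(records):
    """Average the loss over a list of parsed records."""
    return sum(loss for _, _, loss in records) / len(records)

# First four and last four complete lines copied from the log above.
sample = """\
Epoch 1 Step 1 Train Loss: 0.5698
Epoch 1 Step 51 Train Loss: 0.4824
Epoch 1 Step 101 Train Loss: 0.4848
Epoch 1 Step 151 Train Loss: 0.4174
Epoch 1 Step 1151 Train Loss: 0.4536
Epoch 1 Step 1201 Train Loss: 0.4719
Epoch 1 Step 1251 Train Loss: 0.4553
Epoch 1 Step 1301 Train Loss: 0.5108
"""
records = parse_train_log(sample)
print(f'first four: {mean_loss(records[:4]):.4f}')
print(f'last four:  {mean_loss(records[4:]):.4f}')
```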

Save and load pretrained models:

In [6]:
pertnet_model.save_model('test_model')
pertnet_model.load_pretrained('test_model')

Make predictions for new perturbations. Each perturbation is a list of one or more gene names:

In [7]:
pertnet_model.predict([['FEV'], ['FEV', 'SAMD11']])

{'FEV': array([-2.3899270e-08,  3.0855382e-02,  7.6131426e-02, ...,
         3.6461372e+00,  7.0264195e-03, -4.6351711e-32], dtype=float32),
 'FEV_SAMD11': array([-2.3899270e-08,  2.9599186e-02,  7.5969048e-02, ...,
         3.6660352e+00,  2.4150661e-03, -4.6351711e-32], dtype=float32)}
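
The output above is keyed by the perturbed gene names joined with `_`, and each value is one vector of predicted expression. The sketch below shows how such a vector could be paired with gene names. It assumes the vector follows the same order as `pertnet_model.gene_list`, which this notebook does not state explicitly. The helpers are hypothetical and the numbers are toy stand-ins, not model output.

```python
def prediction_key(perturbation):
    """Join gene names with '_' to form the key seen in the output above."""
    return '_'.join(perturbation)

def expression_by_gene(values, gene_list):
    """Pair each predicted value with a gene name (assumes matching order)."""
    if len(values) != len(gene_list):
        raise ValueError('prediction and gene list differ in length')
    return dict(zip(gene_list, values))

# Toy stand-ins: a real call would use the dict returned by predict()
# and pertnet_model.gene_list instead.
toy_genes = ['RP11-34P13.8', 'RP11-54O7.3', 'SAMD11']
toy_prediction = {'FEV_SAMD11': [0.1, 0.2, 0.3]}

key = prediction_key(['FEV', 'SAMD11'])
by_gene = expression_by_gene(toy_prediction[key], toy_genes)
print(key, by_gene['SAMD11'])
```

The same helpers work unchanged when the values are NumPy arrays, since `len` and `zip` accept them.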

The gene list is available via the `gene_list` attribute:

In [8]:
pertnet_model.gene_list

['RP11-34P13.8',
 'RP11-54O7.3',
 'SAMD11',
 'PERM1',
 'HES4',
 'ISG15',
 'RP11-54O7.18',
 'RNF223',
 'LINC01342',
 'TTLL10-AS1',
 'TNFRSF18',
 'TNFRSF4',
 'TAS1R3',
 'ANKRD65',
 'MMP23B',
 'RP11-345P4.7',
 'CALML6',
 'RP5-892K4.1',
 'RP11-181G12.4',
 'PLCH2',
 'RP3-395M20.12',
 'RP3-395M20.8',
 'TNFRSF14',
 'TTC34',
 'TP73',
 'SMIM1',
 'RP1-286D6.5',
 'CHD5',
 'LINC00337',
 'GPR153',
 'HES2',
 'ESPN',
 'TAS1R1',
 'RP11-338N10.1',
 'TNFRSF9',
 'RP5-1115A15.1',
 'RP4-633I8.4',
 'ENO1',
 'CA6',
 'GPR157',
 'MIR34AHG',
 'RP3-510D11.2',
 'H6PD',
 'PIK3CD-AS1',
 'AL357140.1',
 'PGD',
 'C1orf127',
 'RP4-635E18.9',
 'MASP2',
 'RP4-635E18.8',
 'SRM',
 'DRAXIN',
 'MTHFR',
 'PDPN',
 'TMEM51-AS1',
 'RP3-467K16.2',
 'RP3-467K16.4',
 'EFHD2',
 'CELA2A',
 'CELA2B',
 'RP4-680D5.9',
 'RP11-276H7.3',
 'RP4-798A10.7',
 'RP5-1182A14.5',
 'AC004824.2',
 'RP1-43E13.2',
 'PLA2G2A',
 'PLA2G2D',
 'PLA2G2C',
 'PINK1-AS',
 'RP3-329E20.2',
 'CELA3A',
 'CDC42',
 'WNT4',
 'C1QB',
 'ZNF436-AS1',
 'RP1-150O5.3',
 'I