In [None]:
import sys

sys.path.append("../")

from gears import GEARS, PertData

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


Load the data. We use the Norman dataset (`data_name="norman"`) as an example.

In [2]:
pert_data = PertData("./data")
pert_data.load(data_name="norman")
pert_data.prepare_split(split="simulation", seed=1)
pert_data.get_dataloader(batch_size=32, test_batch_size=128)

Downloading...
100%|██████████| 169M/169M [00:09<00:00, 17.5MiB/s] 
Extracting zip file...
Done!
Creating pyg object for each cell in the data...
100%|██████████| 284/284 [02:21<00:00,  2.00it/s]
Saving new dataset pyg object at ./data/norman/data_pyg/cell_graphs.pkl
Done!
Creating new splits....
Saving new splits at ./data/norman/splits/norman_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Done!
Creating dataloaders....
Done!
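
The test-split composition printed above is easier to compare across seeds when expressed as proportions. The following is a minimal sketch that tallies the counts copied from the log; the variable names are illustrative and not part of GEARS.

```python
# Test-split composition, copied verbatim from the log output above.
log = """\
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
"""

# Parse each "name:count" line into a dictionary.
composition = {}
for line in log.splitlines():
    name, count = line.split(":")
    composition[name] = int(count)

# Report each category with its share of the test perturbations.
total = sum(composition.values())
for name, count in composition.items():
    print(f"{name:<14} {count:>3}  ({count / total:.1%})")
print(f"{'total':<14} {total:>3}")
```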


Create a model object. If you use [wandb](https://wandb.ai), you can track model training and evaluation by setting `weight_bias_track=True` and choosing a `proj_name` and `exp_name`.

In [5]:
gears_model = GEARS(
    pert_data,
    device="mps",
    weight_bias_track=False,
    proj_name="pertnet",
    exp_name="pertnet",
)
gears_model.model_initialize(hidden_size=64)
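
The cell above hard-codes `device="mps"`, which only exists on Apple-silicon machines. A small helper can pick a device at run time instead. This is a sketch, not part of GEARS: `pick_device` is a hypothetical name, and it assumes that `GEARS` accepts any PyTorch device string in the same way it accepts `"mps"`.

```python
def pick_device():
    """Return "cuda", "mps", or "cpu", depending on what PyTorch reports."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to query; fall back to the CPU name.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing in older PyTorch releases, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(device)
```

The result could then be passed as `GEARS(pert_data, device=device, ...)`.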

You can list the tunable parameters of `model_initialize` via:

In [6]:
gears_model.tunable_parameters()

{'hidden_size': 'hidden dimension, default 64',
 'num_go_gnn_layers': 'number of GNN layers for GO graph, default 1',
 'num_gene_gnn_layers': 'number of GNN layers for co-expression gene graph, default 1',
 'decoder_hidden_size': 'hidden dimension for gene-specific decoder, default 16',
 'num_similar_genes_go_graph': 'number of maximum similar K genes in the GO graph, default 20',
 'num_similar_genes_co_express_graph': 'number of maximum similar K genes in the co expression graph, default 20',
 'coexpress_threshold': 'pearson correlation threshold when constructing coexpression graph, default 0.4',
 'uncertainty': 'whether or not to turn on uncertainty mode, default False',
 'uncertainty_reg': 'regularization term to balance uncertainty loss and prediction loss, default 1',
 'direction_lambda': 'regularization term to balance direction loss and prediction loss, default 1'}
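
Each description above ends with `default <value>`, so the defaults can be recovered programmatically, for example to log them next to an experiment. A minimal sketch using a few of the descriptions copied from the output; `parse_default` is a hypothetical helper, and it assumes every description keeps this `default <value>` suffix.

```python
import ast

# A few descriptions copied from the tunable_parameters() output above.
descriptions = {
    "hidden_size": "hidden dimension, default 64",
    "decoder_hidden_size": "hidden dimension for gene-specific decoder, default 16",
    "coexpress_threshold": "pearson correlation threshold when constructing coexpression graph, default 0.4",
    "uncertainty": "whether or not to turn on uncertainty mode, default False",
}


def parse_default(description):
    """Return the text after the last 'default ' as a Python literal."""
    raw = description.rsplit("default ", 1)[1]
    return ast.literal_eval(raw)


defaults = {name: parse_default(text) for name, text in descriptions.items()}
print(defaults)
```

A dictionary of overrides built this way could be unpacked into the call, e.g. `gears_model.model_initialize(**{**defaults, "hidden_size": 128})`, assuming each key is accepted as a keyword argument in the same way `hidden_size` is above.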

Train your model:

Note: for this demo, we train for only 1 epoch. To train the full model, set `epochs=20`.

In [7]:
gears_model.train(epochs=1, lr=1e-3)

Start Training...
Epoch 1 Step 1 Train Loss: 0.4119
Epoch 1 Step 51 Train Loss: 0.5249
Epoch 1 Step 101 Train Loss: 0.4199
Epoch 1 Step 151 Train Loss: 0.4127
Epoch 1 Step 201 Train Loss: 0.5316
Epoch 1 Step 251 Train Loss: 0.5866
Epoch 1 Step 301 Train Loss: 0.4308
Epoch 1 Step 351 Train Loss: 0.4717
Epoch 1 Step 401 Train Loss: 0.5408
Epoch 1 Step 451 Train Loss: 0.4875
Epoch 1 Step 501 Train Loss: 0.4131
Epoch 1 Step 551 Train Loss: 0.5243
Epoch 1 Step 601 Train Loss: 0.4407
Epoch 1 Step 651 Train Loss: 0.4738
Epoch 1 Step 701 Train Loss: 0.4795
Epoch 1 Step 751 Train Loss: 0.4903
Epoch 1 Step 801 Train Loss: 0.4435
Epoch 1 Step 851 Train Loss: 0.4260
Epoch 1 Step 901 Train Loss: 0.4202
Epoch 1 Step 951 Train Loss: 0.5483
Epoch 1 Step 1001 Train Loss: 0.4399
Epoch 1 Step 1051 Train Loss: 0.4982
Epoch 1 Step 1101 Train Loss: 0.4368
Epoch 1 Step 1151 Train Loss: 0.4591
Epoch 1 Step 1201 Train Loss: 0.4688
Epoch 1 Step 1251 Train Loss: 0.4625
Epoch 1 Step 1301 Train Loss: 0.4103
Epoch 

RuntimeError: MPS backend out of memory (MPS allocated: 8.92 GB, other allocations: 9.14 GB, max allowed: 18.13 GB). Tried to allocate 78.83 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
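
The run above stopped partway through the epoch because the MPS backend ran out of memory. Besides the environment variable named in the error message, a smaller `batch_size` in `get_dataloader` may reduce memory use, and so may retrying on another device. The sketch below illustrates the retry idea with a stand-in training function; `run_with_device_fallback` and `fake_run` are hypothetical names, not GEARS functions.

```python
def run_with_device_fallback(run, devices=("mps", "cpu")):
    """Call run(device) for each device in order.

    Moves on to the next device only when a RuntimeError mentions
    "out of memory"; any other error is re-raised immediately.
    """
    last_error = None
    for device in devices:
        try:
            return device, run(device)
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise
            last_error = err
    if last_error is None:
        raise ValueError("no devices were given")
    raise last_error


# Stand-in for a real training run: fails on "mps" the way the log above does.
def fake_run(device):
    if device == "mps":
        raise RuntimeError("MPS backend out of memory")
    return f"trained on {device}"


used_device, result = run_with_device_fallback(fake_run)
print(used_device, result)
```

In this notebook, `run` would be a function that builds `GEARS(pert_data, device=device, ...)`, calls `model_initialize`, and then calls `train`, so that every attempt starts from a fresh model on the chosen device.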

Save and load pretrained models:

In [6]:
gears_model.save_model("test_model")
gears_model.load_pretrained("test_model")

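Make predictions for new perturbations. Each perturbation is given as a list of gene names, and the result is a dictionary keyed by those names joined with `_`: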

In [7]:
gears_model.predict([["FEV"], ["FEV", "AHR"]])

{'FEV': array([-1.5115363e-06,  4.4304952e-02,  1.0309354e-01, ...,
         3.3967001e+00,  7.8529231e-03,  1.0920237e-31], dtype=float32),
 'FEV_SAMD11': array([-2.2916190e-06,  9.7577907e-02,  1.6493453e-01, ...,
         3.2082996e+00,  7.6769367e-03,  1.7619579e-31], dtype=float32)}

The gene list is available as `gears_model.gene_list`:

In [8]:
gears_model.gene_list[:5]

['RP11-34P13.8', 'RP11-54O7.3', 'SAMD11', 'PERM1', 'HES4']
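
To read a prediction, it helps to pair each value with its gene name. The sketch below assumes that every predicted vector is ordered like `gene_list`, which the notebook does not state explicitly. `top_genes` is a hypothetical helper, and the values are made up for illustration rather than taken from a real prediction.

```python
def top_genes(prediction, gene_list, k=3):
    """Return the k (gene, value) pairs with the highest predicted value."""
    pairs = zip(gene_list, (float(value) for value in prediction))
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:k]


# First five gene names from the output above, with made-up values.
genes = ["RP11-34P13.8", "RP11-54O7.3", "SAMD11", "PERM1", "HES4"]
values = [0.0, 0.04, 0.10, 3.40, 0.01]

for gene, value in top_genes(values, genes, k=2):
    print(f"{gene}: {value:.2f}")
```

With a real model, `values` would be one entry of the dictionary returned by `predict` and `genes` would be `gears_model.gene_list`.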