This hands-on exercise demonstrates how to train and use the GEARS model (Roohani Y, et al. Nature Biotechnology, 2024) on a gene perturbation dataset. You can plug an LCM into the perturbation task by replacing GEARS's initial gene embedding with the LCM's gene representations.

Due to limited computational resources, we won't call the LCMs here to obtain cell representations. If you are interested, you can try it on your own after class.

We first inspect the contents of the perturbation data, then construct and train the GEARS model, and finally use the trained model to predict perturbation effects and evaluate the quality of those predictions.

# Import the necessary libraries

In [1]:
import os
import torch
import pickle

import sys
sys.path.append('/kaggle/input/gears-tutorial/pytorch/default/1/GEARS')
from gears import GEARS
from gears.inference import evaluate
from gears.model import GEARS_Model
from gears.utils import get_similarity_network, GeneSimNetwork
from copy import deepcopy
                  
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import mean_absolute_error as mae

In [2]:
sys.path.append("/kaggle/input/pert_dataset/pytorch/default/1/Pert_Data/")
import v1
from v1.utils import *
from v1.dataloader import *

# Load perturbation dataset

**Load and process data**

In [3]:
# - load pert_data and preprocess
# Byte_Pert_Data comes from the v1 star imports. The return values below are not used,
# so each step presumably updates pert_data in place; keep this order, since later steps
# appear to depend on what earlier ones computed.
pert_data = Byte_Pert_Data(data_dir='/kaggle/input/example-data-pert/',prefix='NormanWeissman2019_filtered',) # dataset prefix: NormanWeissman2019_filtered or XuCao2023
pert_data.read_files()   # presumably loads the files for `prefix` from data_dir
pert_data.filter_perturbation()   # drop perturbations that have too few cells
# Keep ~1000 highly variable genes. The final gene count printed below (1067) exceeds var_num,
# presumably because perturbed genes are retained as well -- confirm in v1.dataloader.
pert_data.get_and_process_adata(var_num=1000)    # process the data and obtain the highly variable genes
# The meaning of split_type=1 is defined in v1.dataloader; the later loader cell builds train, test and val sets.
pert_data.data_split(split_type=1)   # split the perturbations into train / test / val sets
pert_data.set_control_barcode()   # assign a control cell's barcode to each perturbed cell
pert_data.filter_sgRNA()  # for each perturbation, count its sgRNAs and filter if needed (the output shows none needed filtering)
pert_data.get_de_genes()  # calculate the DE genes for each perturbation; used for evaluation

retain_pert_num is: 196
filtered pert num is:  41


100%|██████████| 236/236 [00:04<00:00, 48.26it/s]


len of exclude_var_list is 4
len of pert_gene_list is 289
len of final var_names is 1067
this is new version


100%|██████████| 195/195 [00:04<00:00, 47.93it/s]




  0%|          | 0/48 [00:00<?, ?it/s]2025-06-18 07:14:42.840812: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750230883.023344      46 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750230883.083045      46 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

100%|██████████| 10/10 [00:00<00:00, 554.14it/s]
  2%|▏         | 1/48 [00:27<21:15, 27.14s/it]
100%|██████████| 6/6 [00:00<00:00, 129.66it/s]
  4%|▍         | 2/48 [00:27<08:44, 11.40s/it]
100%|██████████| 6/6 [00:00<00:00, 101.37it/s]
  6%|▋         | 3/48 [00:27<04:46,  6.36s/it]
100%|██████████| 6/6 [00:00<00:00, 73.72it/s]
  8%|▊         | 4/48 [00:28<02:57,  4.02s/it]
100%|██████████| 6/6 

C19orf26 not in var!



100%|██████████| 6/6 [00:00<00:00, 88.70it/s]
 12%|█▎        | 6/48 [00:29<01:20,  1.93s/it]
100%|██████████| 6/6 [00:00<00:00, 114.66it/s]
 15%|█▍        | 7/48 [00:29<00:58,  1.42s/it]
100%|██████████| 6/6 [00:00<00:00, 200.91it/s]
 17%|█▋        | 8/48 [00:29<00:42,  1.06s/it]
100%|██████████| 6/6 [00:00<00:00, 146.03it/s]
 19%|█▉        | 9/48 [00:30<00:32,  1.20it/s]
100%|██████████| 6/6 [00:00<00:00, 104.77it/s]
 21%|██        | 10/48 [00:30<00:26,  1.46it/s]
100%|██████████| 6/6 [00:00<00:00, 144.28it/s]
 23%|██▎       | 11/48 [00:30<00:21,  1.72it/s]
100%|██████████| 6/6 [00:00<00:00, 87.73it/s]
 25%|██▌       | 12/48 [00:31<00:18,  1.90it/s]
100%|██████████| 6/6 [00:00<00:00, 121.55it/s]
 27%|██▋       | 13/48 [00:31<00:16,  2.09it/s]
100%|██████████| 6/6 [00:00<00:00, 97.77it/s]
 29%|██▉       | 14/48 [00:31<00:15,  2.23it/s]
100%|██████████| 6/6 [00:00<00:00, 95.87it/s]
 31%|███▏      | 15/48 [00:32<00:14,  2.32it/s]
100%|██████████| 6/6 [00:00<00:00, 80.45it/s]
 33%|███▎  

filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
filter_sgRNA_list is: []
max_eids < 10, no need to filter sgRNA
filter_sgRNA_list is: []
max_

  0%|          | 0/195 [00:00<?, ?it/s]... storing 'guide_id' as categorical
... storing 'tissue_type' as categorical
... storing 'cell_line' as categorical
... storing 'disease' as categorical
... storing 'perturbation_type' as categorical
... storing 'celltype' as categorical
... storing 'organism' as categorical
... storing 'perturbation' as categorical
... storing 'perturbation_new' as categorical
... storing 'perturbation_type_new' as categorical
... storing 'celltype_new' as categorical
... storing 'sgRNA_new' as categorical
... storing 'perturbation_group' as categorical
... storing 'data_split' as categorical
... storing 'retain' as categorical
... storing 'control_barcode' as categorical
... storing 'sgRNA_ID' as categorical
... storing 'pert_sgRNA' as categorical
  1%|          | 1/195 [00:06<21:14,  6.57s/it]... storing 'guide_id' as categorical
... storing 'tissue_type' as categorical
... storing 'cell_line' as categorical
... storing 'disease' as categorical
... storing 'p






In [4]:
# Presumably seeds the random number generators for reproducibility (fix_seed comes from the v1 star imports)
fix_seed(2024)

# - get GO genes; a special gene set for GEARS, used to construct the GEARS GO graph
pert_data.get_gene2go()   # get gene2go dict: {'gene1': [go1, go2, ...], 'gene2': [go1, go3, ...], ...}
pert_data.set_pert_genes()  # get the list of perturbable genes to include in the perturbation graph

# - transform the dataset into the format GEARS requires (one pass per split: train, test, val)
# NOTE(review): add_control=False presumably excludes control cells from these datasets -- confirm in v1.dataloader
pert_data.get_Data_gears(num_de_genes = pert_data.num_de_genes,
                        dataset_name = ['train', 'test', 'val'],
                        add_control = False)
# - add the elements GEARS needs (the output reports adding adata, condition and set2conditions)
pert_data.modify_gears()

# - get dataloaders; assumes get_dataloader(mode='all') returns (train, test, val) in this order -- confirm
trainloader, testloader, valloader = pert_data.get_dataloader(mode='all')

Found local copy...
Found local copy...
 68%|██████▊   | 92/136 [00:03<00:01, 28.96it/s]

LYL1; IER5L | lymphoblasts not in pert_names


100%|██████████| 136/136 [00:05<00:00, 23.33it/s]
100%|██████████| 39/39 [00:01<00:00, 36.74it/s]


KIAA1804 | lymphoblasts not in pert_names
IER5L | lymphoblasts not in pert_names


100%|██████████| 20/20 [00:00<00:00, 37.22it/s]


add adata finished
add condition finished
add set2conditions finished


**Dataset details**

After preprocessing, this dataset has 106,022 cells, each with 1,067 genes (see the AnnData summary below). The GEARS loaders built above contain 136 training, 39 test and 20 validation perturbation conditions. We have pre-processed the data and calculated the differentially expressed genes (DEGs). These DEGs can be regarded as the genes that change significantly after perturbation, and they are used to assess the quality of the predictions.

In [5]:
# Show the processed AnnData summary: cells x genes plus the obs / var / uns annotations
pert_data.adata

AnnData object with n_obs × n_vars = 106022 × 1067
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'perturbation_new', 'perturbation_type_new', 'nperts_new', 'celltype_new', 'sgRNA_new', 'perturbation_group', 'data_split', 'retain', 'n_genes', 'n_counts_all', 'control_barcode', 'sgRNA_ID', 'pert_sgRNA', 'condition', 'condition_name'
    var: 'ensemble_id', 'ncounts', 'ncells', 'gene_name'
    uns: 'rank_genes_groups', 'pvals', 'pvals_adj', 'scores', 'logfoldchanges', 'top_non_dropout_de_20', 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'top_non_zero_de_20', 'rank_genes_groups_cov_all'

In [6]:
# Unique perturbation conditions in the test split. Index the obs table directly instead of
# slicing the whole AnnData object, which would build an AnnData view just to read one column.
pert_data.adata.obs.loc[pert_data.adata.obs['data_split'] == 'test', 'perturbation_new'].unique()

array(['KMT2A', 'MAP2K3; ELMSAN1', 'UBASH3B; CNN1', 'CEBPE; RUNX1T1',
       'ETS2; MAPK1', 'SET; CEBPE', 'EGR1', 'FOXO4', 'CNN1', 'CELF2',
       'PTPN12; PTPN9', 'ZNF318', 'ETS2; MAP7D1', 'DUSP9', 'MEIS1',
       'ELMSAN1', 'POU3F2; FOXL2', 'MAP7D1', 'FOXL2', 'ETS2; PRTG',
       'CSRNP1', 'ATL1', 'TP73', 'RREB1', 'TBX2', 'CDKN1C', 'KLF1; CEBPA',
       'FOSB', 'CBFA2T3', 'ETS2; IKZF3', 'RHOXF2; ZBTB25', 'KIAA1804',
       'CDKN1A', 'MAPK1; PRTG', 'IER5L', 'CITED1', 'ZC3HAV1',
       'PTPN12; SNAI1', 'CEBPB; OSR2'], dtype=object)

In [7]:
# Unique perturbation conditions in the train split. Index the obs table directly instead of
# slicing the whole AnnData object, which would build an AnnData view just to read one column.
pert_data.adata.obs.loc[pert_data.adata.obs['data_split'] == 'train', 'perturbation_new'].unique()

array(['CEBPB', 'TGFBR2; IGDCC3', 'control', 'SNAI1', 'CNN1; MAPK1',
       'GLB1L2', 'OSR2', 'MAP2K6', 'LHX1; ELMSAN1', 'SET; KLF1', 'CBL',
       'CEBPE', 'DUSP9; KLF1', 'S1PR2', 'CEBPE; PTPN12', 'UBASH3B; PTPN9',
       'KLF1', 'TMSB4X; BAK1', 'SLC6A9', 'ETS2; IGDCC3', 'AHR; FEV',
       'STIL', 'ARRDC3', 'TGFBR2; C19orf26', 'TSC22D1', 'COL2A1',
       'FOXF1; HOXB9', 'FOXA1', 'BAK1', 'ETS2', 'CBL; CNN1', 'MAP2K3',
       'SLC4A1', 'PTPN12; OSR2', 'COL1A1', 'LYL1', 'ZNF318; FOXL2',
       'TMSB4X', 'KLF1; CLDN6', 'MAPK1; TGFBR2', 'ISL2', 'FOXA1; HOXB9',
       'ZBTB10; PTPN12', 'FOXF1', 'BCL2L11', 'MAML2', 'AHR; KLF1', 'HK2',
       'CEBPB; PTPN12', 'PRDM1', 'IKZF3', 'KLF1; TGFBR2', 'ETS2; CEBPE',
       'FOSB; PTPN12', 'UBASH3A', 'MAP4K3', 'FOXA3', 'PTPN12; ZBTB25',
       'FOXA3; FOXA1', 'DLX2', 'CBL; PTPN9', 'UBASH3B', 'SAMD1',
       'TGFBR2; ETS2', 'SGK1; TBX3', 'HOXB9', 'LHX1', 'HOXC13',
       'TBX3; TBX2', 'BCORL1', 'C19orf26', 'UBASH3B; PTPN12', 'HOXA13',
       'IGDCC3; MA

In [8]:
# Each cell is labelled with its perturbed gene, or several genes joined by '; '
# (e.g. 'CNN1; MAPK1'); unperturbed cells are labelled 'control'.
pert_data.adata.obs['perturbation_new'].head(5)

AAACCTGAGAAACCAT             CEBPB
AAACCTGAGAAAGTGG    TGFBR2; IGDCC3
AAACCTGAGAAGAAGC           control
AAACCTGAGAAGGTTT             SNAI1
AAACCTGAGACATAAC       CNN1; MAPK1
Name: perturbation_new, dtype: object

In [9]:
# Each perturbed cell is also paired with a control-group cell (its state before
# perturbation), referenced by barcode; control cells themselves show None.
pert_data.adata.obs['control_barcode'].head(5)

AAACCTGAGAAACCAT    GGAATAACAAGGTTCT
AAACCTGAGAAAGTGG    GGAATAACAAGGTTCT
AAACCTGAGAAGAAGC                None
AAACCTGAGAAGGTTT    GGAATAACAAGGTTCT
AAACCTGAGACATAAC    GGAATAACAAGGTTCT
Name: control_barcode, dtype: object

In [10]:
import itertools

# Differentially expressed genes (DEGs) for each perturbation condition.
# adata.uns['top_non_zero_de_20'] maps a condition key (e.g. 'CEBPB | lymphoblasts') to its
# top 20 non-zero-expression DEGs, chosen from adata.uns['rank_genes_groups_cov_all'].
# Take just the first 5 entries lazily instead of listing every item first.
print(dict(itertools.islice(pert_data.adata.uns['top_non_zero_de_20'].items(), 5)))

{'CEBPB | lymphoblasts': array(['PLD3', 'SH3BGRL3', 'AIF1', 'LST1', 'FTL', 'RP11-301G19.1',
       'TMSB10', 'TMSB4X', 'CFD', 'LGI2', 'ID1', 'ID3', 'GYPA', 'HBG2',
       'GAL', 'MSRB1', 'CSF3R', 'ARPC1B', 'TYROBP', 'MYO1F'], dtype='<U13'), 'TGFBR2; IGDCC3 | lymphoblasts': array(['IGDCC3', 'ALAS2', 'HBZ', 'HBG2', 'RP11-301G19.1', 'GYPB',
       'PRSS57', 'HBG1', 'BST2', 'CPEB4', 'FAM83A', 'GYPA', 'HBA1',
       'IGFBP2', 'YBX1', 'MDK', 'TMSB10', 'GAL', 'AC079466.1', 'APOE'],
      dtype='<U13'), 'SNAI1 | lymphoblasts': array(['TMSB10', 'HBG2', 'S100A13', 'LGALS1', 'HIST1H1C', 'RP11-717F1.1',
       'CFD', 'RNASET2', 'GAL', 'HBZ', 'HIST1H2BJ', 'MDK', 'PRSS57',
       'MT-ND2', 'ALAS2', 'BST2', 'YBX1', 'VIM', 'S100A11', 'NCL'],
      dtype='<U12'), 'CNN1; MAPK1 | lymphoblasts': array(['CNN1', 'TMSB4X', 'MAPK1', 'GAL', 'ETS2', 'AIF1', 'SH3BGRL3',
       'HBG2', 'CTSL', 'GMFG', 'AC079466.1', 'ZFP36L1', 'PRSS57',
       'S100A11', 'RPL3', 'ARHGDIB', 'ISG15', 'HBG1', 'MDK', 'COTL1'],
      d

# Initialize the GEARS model

Set GEARS model parameters

In [11]:
# - init gears model
# Densify the expression matrix before building GEARS. The guard makes re-running this
# cell safe: after the first run X is already a dense ndarray, which has no .toarray()
# and would raise AttributeError.
if hasattr(pert_data.adata_split.X, 'toarray'):
    pert_data.adata_split.X = pert_data.adata_split.X.toarray()
gears_model = GEARS(pert_data, device='cuda:0',
                    weight_bias_track=False,
                    proj_name='pertnet',
                    exp_name='pertnet')

# - set model configuration
# This dict replaces gears_model.config entirely. The graph-construction cell fills in the
# G_* entries and then builds GEARS_Model from this config.
gears_model.config = {
    'hidden_size': 64,
    'num_go_gnn_layers': 1,
    'num_gene_gnn_layers': 1,
    'decoder_hidden_size': 16,
    'num_similar_genes_go_graph': 20,          # k for the GO graph
    'num_similar_genes_co_express_graph': 20,  # k for the co-expression graph
    'coexpress_threshold': 0.4,                # similarity threshold used when building the graphs
    'uncertainty': False,                      # also passed to evaluate() at test time
    'uncertainty_reg': 1,
    'direction_lambda': 1e-1,
    'G_go': None,                              # graph entries are filled in by the next cell
    'G_go_weight': None,
    'G_coexpress': None,
    'G_coexpress_weight': None,
    'device': gears_model.device,
    'num_genes': gears_model.num_genes,
    'num_perts': gears_model.num_perts,
    'no_perturb': False,
}

Construct the co-expression graph and the GO (Gene Ontology) graph. These two graphs are used to build the graph neural network.

In [None]:
# - Set the gene co-expression network (green graph)
if gears_model.config['G_coexpress'] is None:  # build it only if no co-expression graph was supplied
    # Compute the gene-gene co-expression similarity edges
    edge_list = get_similarity_network(
        network_type='co-express',
        adata=gears_model.adata,                                      # expression data
        threshold=gears_model.config['coexpress_threshold'],         # edge-creation threshold
        k=gears_model.config['num_similar_genes_co_express_graph'],  # similar genes connected per gene
        data_path=gears_model.data_path,
        data_name=gears_model.dataset_name,
        split=gears_model.split,
        seed=gears_model.seed,
        train_gene_set_size=gears_model.train_gene_set_size,
        set2conditions=gears_model.set2conditions,
    )
    # Nodes are the genes in gears_model.gene_list, indexed through node_map
    sim_network = GeneSimNetwork(edge_list, gears_model.gene_list,
                                 node_map=gears_model.node_map)
    gears_model.config['G_coexpress'] = sim_network.edge_index
    gears_model.config['G_coexpress_weight'] = sim_network.edge_weight

# - Set the gene ontology network (red graph)
if gears_model.config['G_go'] is None:  # build it only if no GO graph was supplied
    # Compute the gene ontology similarity edges
    edge_list = get_similarity_network(
        network_type='go',
        adata=gears_model.adata,
        threshold=gears_model.config['coexpress_threshold'],
        # Use the GO-specific neighbour count; passing the co-expression key here would
        # silently ignore 'num_similar_genes_go_graph' in the config.
        k=gears_model.config['num_similar_genes_go_graph'],
        pert_list=gears_model.pert_list,
        data_path=gears_model.data_path,
        data_name=gears_model.dataset_name,
        split=gears_model.split,
        seed=gears_model.seed,
        train_gene_set_size=gears_model.train_gene_set_size,
        set2conditions=gears_model.set2conditions,
        default_pert_graph=gears_model.default_pert_graph,
    )
    # Nodes are the perturbable genes in gears_model.pert_list, indexed through node_map_pert
    sim_network = GeneSimNetwork(edge_list, gears_model.pert_list,
                                 node_map=gears_model.node_map_pert)
    gears_model.config['G_go'] = sim_network.edge_index
    gears_model.config['G_go_weight'] = sim_network.edge_weight

# - finally obtain the model: instantiate it and move it to the configured device
gears_model.model = GEARS_Model(gears_model.config).to(gears_model.device)
gears_model.best_model = deepcopy(gears_model.model)  # keep a copy as the best model so far

Found local copy...


this is lichen version 2!


# Training

Below, simplified pseudocode illustrates the core of the GEARS model. For the details of each line, refer to the GEARS source code.

**Get Base Gene Embeddings**  
In the following loop we use the co-expression graph to obtain the final gene embeddings (green part in the figure above).
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
base_emb = self.gene_emb(self.num_genes)
pos_emb = self.emb_pos(self.num_genes)
for idx, gnn_layer in enumerate(self.gnn_layers_exp):
    pos_emb = gnn_layer(pos_emb, self.G_coexpress, self.G_coexpress_weight)
base_emb = base_emb + pos_emb
</code></pre>
</div>
Note that the gene embedding here can be directly replaced with an LCM gene embedding (such as scFoundation's embedding). LCM embeddings are pre-trained on huge amounts of data and carry richer, more comprehensive cellular information, which can improve performance on the perturbation task.

**Get Perturbation Embeddings**
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
pert_global_emb = self.pert_emb(self.num_perts)
</code></pre>
</div>     

In the following loop we use the GO graph to obtain the final perturbation embeddings (red part in the figure above).
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
for idx, gnn_layer in enumerate(self.gnn_layers_go):
    pert_global_emb = gnn_layer(pert_global_emb, self.G_sim, self.G_sim_weight)
</code></pre>
</div>

**Add Global Perturbation Embedding to Each Gene in Each Cell in the Batch**  
(Composition Operator in the figure above)  
Select the perturbation embedding of the corresponding gene
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
pert_track = {}
for i, j in enumerate(pert_index[0]):
    pert_track[j.item()] = pert_global_emb[pert_index[1][i]]
</code></pre>
</div> 
Add the selected perturbation embedding to the gene embedding
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))
for idx, j in enumerate(pert_track.keys()):
    base_emb[j] = base_emb[j] + emb_total[idx]
</code></pre>
</div> 

**Finally Go Through A Fully Connected Network to Obtain the Predicted Expression**  
(Blue part in the figure above)  
<div style="background:#f9f9f9; border:1px solid #ccc; border-radius:6px; padding:12px; color:#333; font-family:monospace;">
<pre><code>
predict_expression = self.fcn(base_emb)
</code></pre>
</div> 

We do not go into the details of the training process here; for the full procedure, see the implementation of `gears_model.train`.

In [13]:
# Train for a single epoch as a quick demo; more epochs should give better results.
gears_model.train(epochs=1, lr=1e-4)

Start Training...
Epoch 1 Step 1 Train Loss: 0.6396
Epoch 1 Step 51 Train Loss: 0.7346
Epoch 1 Step 101 Train Loss: 0.8164
Epoch 1 Step 151 Train Loss: 0.7041
Epoch 1 Step 201 Train Loss: 0.7487
Epoch 1 Step 251 Train Loss: 0.6544
Epoch 1 Step 301 Train Loss: 0.6141
Epoch 1 Step 351 Train Loss: 0.5417
Epoch 1 Step 401 Train Loss: 0.6686
Epoch 1 Step 451 Train Loss: 0.6415
Epoch 1 Step 501 Train Loss: 0.6213
Epoch 1 Step 551 Train Loss: 0.6144
Epoch 1 Step 601 Train Loss: 0.6536
Epoch 1 Step 651 Train Loss: 0.6656
Epoch 1 Step 701 Train Loss: 0.5710
Epoch 1 Step 751 Train Loss: 0.7103
Epoch 1 Step 801 Train Loss: 0.6331
Epoch 1 Step 851 Train Loss: 0.8571
Epoch 1 Step 901 Train Loss: 0.6516
Epoch 1 Step 951 Train Loss: 0.6747
Epoch 1 Step 1001 Train Loss: 0.7110
Epoch 1 Step 1051 Train Loss: 0.6432
Epoch 1 Step 1101 Train Loss: 0.5850
Epoch 1 Step 1151 Train Loss: 0.6577
Epoch 1 Step 1201 Train Loss: 0.6702
Epoch 1 Step 1251 Train Loss: 0.6095
Epoch 1 Step 1301 Train Loss: 0.5903
Epoch 

# Evaluate the performance  
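Here we use the test set to evaluate the model's performance. The test perturbation conditions are held out from training. However, individual genes from a test combination may still be perturbed in training: for example, CEBPE and ETS2 appear in both splits.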

In [14]:
# - Model testing: run the best model snapshot over the test loader
test_output = evaluate(
    testloader,
    gears_model.best_model,
    gears_model.config['uncertainty'],
    gears_model.device,
)

We can then calculate the evaluation metrics.

In [15]:
# Metrics for the example perturbation are collected in this dict
pert_metric = {}
# Use one perturbation as an example: the first condition (np.unique sorts) in the test output.
# prepare_for_metric first gathers the inputs the metric calculation needs:
#   de_idx_map: DEG-set name -> gene indices (top20, top50, top100, ... per the output below)
#   ctrl:       the control-group reference, subtracted to form deltas in the next cell
#   p_idx:      indices of the cells carrying this perturbation
de_idx_map, ctrl, p_idx = prepare_for_metric(
    pert_data.adata,
    test_output,
    pert=np.unique(test_output['pert_cat'])[0],
    most_variable_genes=None,
    p_thre_1=0.01,
    p_thre_2=0.1,
)

We use the following metrics for evaluation

**pearson correlation**: measures the linear correlation between predicted and true values  
**mean squared error**: measures the average of the squares of the errors  
**mean absolute error**: measures the average of the absolute errors  
**change ratio**: measures the relative change between predicted and true values  
**spearman correlation**: measures the monotonic relationship between predicted and true values  

In [None]:
# Map each metric name to the function that computes it
metric2fct = {
    'pearson': pearsonr,               # Pearson correlation coefficient
    'mse': mse,                        # mean squared error
    'mae': mae,                        # mean absolute error
    'change_ratio': get_change_ratio,  # custom change-ratio metric (from the v1 star imports)
    'spearman': spearmanr,             # Spearman rank correlation
}
name = 'DE_'  # prefix for every key stored in pert_metric

# Mean predicted / true expression over the cells of this perturbation. These do not
# depend on the DEG set or on the metric, so compute them once instead of per metric.
pred_mean = test_output['pred'][p_idx].mean(0)
truth_mean = test_output['truth'][p_idx].mean(0)

# de_idx_map maps each DEG-set name (e.g. 'top20') to the gene indices in that set
for prefix, de_idx in de_idx_map.items():
    pred_de = pred_mean[de_idx]
    truth_de = truth_mean[de_idx]
    # Expression change relative to the control group (perturbed - control)
    pred_delta = pred_de - ctrl[de_idx]
    truth_delta = truth_de - ctrl[de_idx]
    for m, fct in metric2fct.items():
        if m in ('pearson', 'spearman'):
            # Correlation on the delta expression; element [0] is the statistic.
            # NaN results (e.g. when an input is constant) are stored as 0.
            val = fct(pred_delta, truth_delta)[0]
            if np.isnan(val):
                val = 0
            pert_metric[name + m + f'_delta_{prefix}'] = val

            # Correlation on the raw predicted vs. true expression
            val = fct(pred_de, truth_de)[0]
            if np.isnan(val):
                val = 0
            pert_metric[name + m + f'_{prefix}'] = val
        elif m == 'change_ratio':
            # Custom change ratio on the raw expression
            pert_metric[name + m + f'_{prefix}'] = fct(pred_de, truth_de)
        else:
            # mse / mae on the delta expression (perturbed - control)
            pert_metric[name + m + f'_{prefix}'] = fct(pred_delta, truth_delta)

In [17]:
pert_metric

{'DE_pearson_delta_top20': 0.6932509,
 'DE_pearson_top20': 0.9079151,
 'DE_mse_top20': 0.29012805,
 'DE_mae_top20': 0.45869684,
 'DE_change_ratio_top20': 0.35481155,
 'DE_spearman_delta_top20': 0.6962406015037594,
 'DE_spearman_top20': 0.8736842105263156,
 'DE_pearson_delta_top50': 0.6354232,
 'DE_pearson_top50': 0.9486616,
 'DE_mse_top50': 0.14973007,
 'DE_mae_top50': 0.32067817,
 'DE_change_ratio_top50': 0.31630465,
 'DE_spearman_delta_top50': 0.6956542617046818,
 'DE_spearman_top50': 0.892436974789916,
 'DE_pearson_delta_top100': 0.54854935,
 'DE_pearson_top100': 0.96458554,
 'DE_mse_top100': 0.08678662,
 'DE_mae_top100': 0.2287849,
 'DE_change_ratio_top100': 0.33833167,
 'DE_spearman_delta_top100': 0.5639603960396039,
 'DE_spearman_top100': 0.9119711971197119,
 'DE_pearson_delta_top200': 0.46559995,
 'DE_pearson_top200': 0.9763876,
 'DE_mse_top200': 0.045821756,
 'DE_mae_top200': 0.13980067,
 'DE_change_ratio_top200': 1.1450596,
 'DE_spearman_delta_top200': 0.42882098767629945,
 'D

### The model was trained for only 1 epoch, so the results are not yet good — try training for more epochs.
You can also try a different training dataset (e.g. `XuCao2023` instead of `NormanWeissman2019_filtered`).