# Example of key regulator inference during cell state transition: transcriptome

To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through transcriptome analysis.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
import glob
from tqdm import tqdm
import torchmetrics as tm 
base_dir =  os.path.expanduser("~/.cache/chrombert/data") ### to_path_chrombert/data

## Preprocess dataset  

This section walks you through preparing raw transcriptome data, including TSS and TPM values for each gene, and transforming it into the format required by ChromBERT.  

To identify key regulators, we will fine-tune ChromBERT to predict log1p-transformed gene expression fold changes during transdifferentiation. This requires careful preparation of the transformed data. Additionally, to analyze shifts in regulator embeddings, we need to identify and prepare datasets for both upregulated and unchanged genes.  


In [2]:
# We provide tables of gene expression data for fibroblast and myoblast.
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
fibroblast_exp = pd.read_csv(f'{gep_dir}/fibroblast_expression.csv')
myoblast_exp = pd.read_csv(f'{gep_dir}/myoblast_expression.csv')
myoblast_exp.head(), fibroblast_exp.head()

(   chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  22.236894
 1  chr19  58347718  ENSG00000268895   9.317134
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.993828
 4  chr12   9065163  ENSG00000245105   0.124228,
    chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  12.774133
 1  chr19  58347718  ENSG00000268895   2.939181
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.226091
 4  chr12   9065163  ENSG00000245105   0.226091)

### Prepare the log1p-transformed gene expression fold changes data

In [3]:
from chrombert.scripts.chrombert_make_dataset import get_regions
# We merge the two datasets, and calculate the log1p-transformed gene expression fold changes.
merge_exp = pd.merge(fibroblast_exp,myoblast_exp,left_on=['chrom','tss','gene_id'],right_on=['chrom','tss','gene_id'],suffixes=['_fibroblast','_myoblast'])
merge_exp['fold_change']= np.log1p(merge_exp['tpm_myoblast']) - np.log1p(merge_exp['tpm_fibroblast'])
merge_exp['start'] = merge_exp['tss']//1000 * 1000
merge_exp['end'] = (merge_exp['tss']//1000 + 1) * 1000
foldchange_exp = merge_exp [['chrom','start','end','tss','gene_id','fold_change']]

# align genomic coordinates to the predefined 1-kb bins
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
chrom_regions
chrom_regions_df = pd.read_csv(chrom_regions,sep='\t',names=['chrom','start','end','build_region_index'])
chrom_regions_df
merge_region = pd.merge(foldchange_exp,chrom_regions_df,left_on=['chrom','start','end'],right_on=['chrom','start','end'],how='inner')[['chrom','start','end','build_region_index','fold_change','tss','gene_id']]
gep_df = merge_region.rename(columns={'fold_change':'label'})
gep_df.to_csv(f'{gep_dir}/fibroblast_to_myoblast_expression_changes.csv',index=False)
gep_df.head() ### The label column holds the log1p-transformed gene expression fold change.

Unnamed: 0,chrom,start,end,build_region_index,label,tss,gene_id
0,chr19,58353000,58354000,917950,0.522949,58353492,ENSG00000121410
1,chr19,58347000,58348000,917944,0.962833,58347718,ENSG00000268895
2,chr10,50885000,50886000,221904,0.0,50885675,ENSG00000148584
3,chr12,9116000,9117000,393001,0.486225,9116229,ENSG00000175899
4,chr12,9065000,9066000,392961,-0.086734,9065163,ENSG00000245105
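
As a sanity check on the label definition, the sketch below recomputes the first row of the table by hand using only the standard library. The TPM values and TSS are copied from the outputs above; the helper names `log1p_fold_change` and `bin_1kb` are illustrative and are not part of ChromBERT.

```python
import math

def log1p_fold_change(tpm_target, tpm_source):
    """Difference of log1p-transformed TPM values (target minus source)."""
    return math.log1p(tpm_target) - math.log1p(tpm_source)

def bin_1kb(tss):
    """Return (start, end) of the 1-kb bin that contains the TSS."""
    start = tss // 1000 * 1000
    return start, start + 1000

# ENSG00000121410: 22.236894 TPM in myoblast, 12.774133 TPM in fibroblast
label = log1p_fold_change(22.236894, 12.774133)
start, end = bin_1kb(58353492)
print(round(label, 4), start, end)
```

The result should agree with the `label`, `start` and `end` columns of the first row to within rounding.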


### Prepare the fine-tuning data
We split the data into training, test, and validation sets at an 8:1:1 ratio, then downsample each split (80, 50, and 20 rows, respectively) so that the fine-tuning demonstration runs quickly.


In [4]:
train_data = gep_df.sample(frac=0.8,random_state=55)
test_data = gep_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = gep_df.drop(train_data.index).drop(test_data.index)
train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=50,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)
train_data_sample.to_csv(f'{gep_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{gep_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{gep_dir}/valid_sample.csv',index=False)
train_data_sample.head()

Unnamed: 0,chrom,start,end,build_region_index,label,tss,gene_id
4496,chr6,138795000,138796000,1719247,0.0,138795911,ENSG00000203734
13237,chr8,96261000,96262000,1933979,-0.351699,96261902,ENSG00000156471
17079,chr22,29555000,29556000,1186796,-0.565257,29555216,ENSG00000100296
4118,chr8,1622000,1623000,1866390,0.0,1622417,ENSG00000253267
13482,chr7,66682000,66683000,1792879,0.402664,66682164,ENSG00000154710
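
The two properties that matter for this split are the 8:1:1 ratio and the absence of overlap between subsets. The following dependency-free sketch reproduces the same idea on plain row indices; it is not the pandas code used in the cell above, and `split_8_1_1` is a hypothetical helper name.

```python
import random

def split_8_1_1(ids, seed=55):
    """Shuffle ids and split them into train/test/valid at an 8:1:1 ratio."""
    rng = random.Random(seed)
    shuffled = list(ids)
    rng.shuffle(shuffled)
    n_train = round(len(shuffled) * 0.8)
    n_test = round((len(shuffled) - n_train) * 0.5)
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    valid = shuffled[n_train + n_test:]
    return train, test, valid

train_ids, test_ids, valid_ids = split_8_1_1(range(1000))
print(len(train_ids), len(test_ids), len(valid_ids))

# No row may appear in more than one subset
overlap = (set(train_ids) & set(test_ids)) | (set(train_ids) & set(valid_ids)) | (set(test_ids) & set(valid_ids))
print(len(overlap))
```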


### Prepare the upregulated genes and unchanged genes
Here we identify upregulated genes (label > 1) and unchanged genes (-0.5 < label < 0.5) from the log1p-transformed gene expression fold changes.
To save time, we then downsample each group to 100 genes.

In [5]:
up_data = gep_df[gep_df['label']>1]
nochange_data = gep_df[(gep_df['label']>-0.5) & (gep_df['label']<0.5)]

up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)


up_data_sample.to_csv(f'{gep_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{gep_dir}/nochange_data_sample.csv',index=False)
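
To make the thresholds concrete: a label above 1 means that the ratio (1 + TPM in myoblast) / (1 + TPM in fibroblast) exceeds e. The sketch below applies the same cutoffs to a few made-up labels; the gene names and values are invented for illustration, and `classify_change` is a hypothetical helper.

```python
import math

def classify_change(label, up_threshold=1.0, unchanged_band=0.5):
    """Assign a label to the 'up', 'unchanged', or 'other' group."""
    if label > up_threshold:
        return "up"
    if -unchanged_band < label < unchanged_band:
        return "unchanged"
    return "other"

# label > 1 is equivalent to (1 + tpm_myoblast) / (1 + tpm_fibroblast) > e
ratio_cutoff = math.exp(1.0)

toy_labels = {"gene_a": 1.8, "gene_b": 0.12, "gene_c": -0.7, "gene_d": 0.75}
groups = {gene: classify_change(value) for gene, value in toy_labels.items()}
print(groups)
print(round(ratio_cutoff, 3))
```

Genes that fall into `other` (downregulated, or only mildly upregulated) are used by neither group in this tutorial.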

## Fine-tune  

This section shows how to fine-tune ChromBERT to predict genome-wide changes in the transcriptome. Because the prediction for each TSS also uses nearby flanking regions, the dataset and model configurations are adjusted as follows:

- **Dataset Configuration:** Use the `multi_flank_window` preset for dataset configuration and set the `flank_window` parameter to four.  
- **Model Instantiation:** Use the `gep` preset for model configuration and set the `gep_flank_window` parameter to four.  
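
To illustrate what a flank window of four implies, the sketch below lists the regions that would be gathered around one TSS bin. It assumes that neighbouring regions are simply those with adjacent `build_region_index` values and ignores chromosome ends and gaps in the region list; this is an assumption made for illustration, not a description of ChromBERT's internal lookup. `flank_region_indices` is a hypothetical helper.

```python
def flank_region_indices(center_index, flank_window):
    """Central region index plus `flank_window` neighbours on each side."""
    return list(range(center_index - flank_window, center_index + flank_window + 1))

# 917950 is the build_region_index of the TSS bin of ENSG00000121410
# in the label table prepared in the preprocessing section.
regions = flank_region_indices(917950, flank_window=4)
print(len(regions), regions[0], regions[-1])
```

Under this assumption, each TSS contributes 2 × 4 + 1 = 9 regions of 1 kb to the model input.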



### Configure dataset and data module

In [6]:
# We use the `multi_flank_window` preset dataset config and set the `flank_window` parameter to 4.
# The `flank_window` parameter sets how many neighbouring regions are included on each side of the TSS.
# A value of 4 means that the 4 nearest genomic regions on either side of the TSS region are used.
dataset_config = chrombert.get_preset_dataset_config(
    "multi_flank_window",
    supervised_file = None, 
    batch_size = 2, 
    num_workers = 4,
    flank_window=4 
    )
dataset_config

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'MultiFlankwindowDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 2, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 4})

In [7]:
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'

In [8]:
# We use the `LitChromBERTFTDataModule` to create a data module for fine-tuning.
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config, 
    train_params = {'supervised_file': f'{gep_dir}/train_sample.csv'}, 
    val_params = {'supervised_file':f'{gep_dir}/valid_sample.csv'}, 
    test_params = {'supervised_file':f'{gep_dir}/test_sample.csv'}
)
data_module.setup()

### Configure model and instantiation


In [9]:
# We use the `gep` preset model config and also set the `gep_flank_window` parameter to four.

model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4)
model_config

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt


ChromBERTFTConfig:
{
    "genome": "hg38",
    "task": "gep",
    "dim_output": 1,
    "mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": false,
    "prompt_kind": "cistrome",
    "prompt_dim_external": 512,
    "dnabert2_ckpt": null
}

In [10]:
model = model_config.init_model()
model.freeze_pretrain(2) ### keep the last 2 transformer blocks trainable; the other 6 ChromBERT transformer blocks are frozen during fine-tuning
summary(model)

use organisim hg38; max sequence length is 6391


  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))


Layer (type:depth-idx)                                       Param #
ChromBERTGEP                                                 --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               (4,916,736)
│    │    └─ModuleList: 3-2                                  51,978,240
├─GeneralHeader: 1-2                                         --
│    └─CistromeEmbeddingManager: 2-2                         --
│    └─Conv2d: 2-3                                           769
│    └─ReLU: 2-4                                             --
│    └─ResidualBlock: 2-5                                    --
│    │    └─Linear: 3-3                                      1,099,776
│    │    └─Linear: 3-4                                      1,049,600
│    │    └─LayerNorm: 3-5                                   2,048
│    │    └─Linear: 3-6                                      1,0

### Configure training parameters and fine-tune  

The model is fine-tuned using PyTorch Lightning with a straightforward parameter configuration. To save time, fine-tuning is performed on a smaller dataset.  

*Note:* Due to the stochastic nature of the tuning process, results may vary. For improved performance, consider increasing the number of epochs and expanding the dataset size.  


In [11]:
train_config = chrombert.finetune.TrainConfig(
    kind='regression',        
    loss='rmse',
    max_epochs=2,
    accumulate_grad_batches=2,
    val_check_interval=2,
    limit_val_batches=10,
    tag='gep',
    checkpoint_mode='max',
    checkpoint_metric='pcc'
    )
pl_module = train_config.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)



chrombert.finetune.train.pl_module.RegressionPLModule
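
The step counts implied by these settings can be worked out by hand. The sketch below combines the `TrainConfig` values with `batch_size = 2` from the dataset configuration and the 80 rows of `train_sample.csv`. It assumes that no incomplete batch is dropped and that an integer `val_check_interval` is counted in training batches within each epoch, which is how PyTorch Lightning documents it.

```python
import math

n_train = 80                 # rows in train_sample.csv
batch_size = 2               # from dataset_config
accumulate_grad_batches = 2  # from TrainConfig
val_check_interval = 2       # validate every 2 training batches
max_epochs = 2

batches_per_epoch = math.ceil(n_train / batch_size)
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
effective_batch_size = batch_size * accumulate_grad_batches
validation_runs = (batches_per_epoch // val_check_interval) * max_epochs

print(batches_per_epoch, optimizer_steps_per_epoch, effective_batch_size, validation_runs)
```

With so few optimizer steps, the resulting model is only a smoke test of the pipeline, which is why the tutorial later switches to a checkpoint trained on the full dataset.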

Now we start fine-tuning.
The trainer writes logs in a TensorBoard-compatible format, and the `ModelCheckpoint` callback keeps the checkpoint with the lowest validation loss (it monitors the loss with `mode="min"`, independently of the `checkpoint_metric` set in `TrainConfig`).



In [12]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1, 
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu", 
    accumulate_grad_batches= train_config.accumulate_grad_batches, 
    fast_dev_run=False, 
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,   
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='gep'))
trainer.fit(pl_module,data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/gep
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type         | Params | Mode 
-----------------------------------------------
0 | model | ChromBERTGEP | 62.8 M | train
-----------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.095   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.



## Evaluate the fine-tuned model  

ChromBERT has been successfully fine-tuned! You can evaluate the model directly using `pl_module.model`. However, note that due to flash-attention settings, the dropout probability cannot be adjusted by `model.eval()`, which may introduce some randomness to the output.  

To obtain consistent results, we recommend saving the fine-tuned checkpoint and reloading it into a model configuration with `dropout=0` before evaluation.

In addition, we use only downsampled test data for evaluation in this tutorial to save time.

### Evaluate the model fine-tuned with limited data

We first load the fine-tuned model and evaluate it on the downsampled test data here.


In [13]:
gep_ft_ckpt = os.path.abspath(glob.glob('./lightning_logs/gep/version*/checkpoints/*.ckpt')[0])
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)

dl = data_module.test_dataloader()
ft_model.cuda()

with torch.no_grad():
    y_preds = []
    y_labels = []
    for idx, batch in enumerate(tqdm(dl, total=len(dl))):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391


  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))


Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/gep/version_0/checkpoints/epoch=0-step=2.ckpt


  new_state = torch.load(ckpt)


Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]

{'pearsonr': tensor(-0.0272), 'spearmanr': tensor(0.0601), 'mse': tensor(0.5545), 'mae': tensor(0.5203), 'r2': tensor(-0.1555)}
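
A negative R² means the predictions fit the labels worse than a constant prediction at the label mean. The plain-Python sketch below spells out the standard definitions of these metrics on made-up numbers; it is not the torchmetrics implementation, and `regression_metrics` is a hypothetical helper.

```python
import math

def regression_metrics(preds, labels):
    """Pearson r, MSE, MAE and R^2 from their textbook definitions."""
    n = len(labels)
    mean_p = sum(preds) / n
    mean_l = sum(labels) / n
    cov = sum((p - mean_p) * (l - mean_l) for p, l in zip(preds, labels))
    var_p = sum((p - mean_p) ** 2 for p in preds)
    var_l = sum((l - mean_l) ** 2 for l in labels)
    sse = sum((p - l) ** 2 for p, l in zip(preds, labels))
    return {
        "pearsonr": cov / math.sqrt(var_p * var_l),
        "mse": sse / n,
        "mae": sum(abs(p - l) for p, l in zip(preds, labels)) / n,
        "r2": 1.0 - sse / var_l,  # < 0 when worse than predicting the mean
    }

toy_labels = [0.0, 0.5, 1.0, -0.5]
good = regression_metrics([0.1, 0.4, 0.9, -0.4], toy_labels)
bad = regression_metrics([0.9, -0.6, 0.2, 0.8], toy_labels)
print(good)
print(bad)
```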





### Performance of the Fine-Tuned Model  

As expected, the model fine-tuned on the downsampled data performs poorly in evaluation. We therefore provide a checkpoint fine-tuned on the entire dataset and evaluate it here on the same downsampled test set.



In [14]:
gep_ft_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)

dl = data_module.test_dataloader()
ft_model.cuda()

with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391


  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))
  new_state = torch.load(ckpt)


Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


100%|██████████| 25/25 [00:23<00:00,  1.08it/s]

{'pearsonr': tensor(0.7469), 'spearmanr': tensor(0.5639), 'mse': tensor(0.2352), 'mae': tensor(0.3242), 'r2': tensor(0.5099)}





## Infer key regulators  

Using the fine-tuned model, we analyze embedding similarities between upregulated and unchanged genomic regions. Lower embedding similarity is hypothesized to indicate a greater functional shift, highlighting the potential role of key regulators in cell state transitions.  

To save time, we perform the analysis on a downsampled set of upregulated and unchanged genomic regions.  

*Note:* Only the selected factors are considered for this analysis, excluding histone modifications and chromatin accessibility.  


In [15]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
    "gep", 
    gep_flank_window = 4,
    dropout = 0,
    finetune_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391


  ck = torch.load(ckpt_path, map_location=torch.device('cpu'))


Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


  new_state = torch.load(ckpt)


Layer (type:depth-idx)                                       Param #
ChromBERTEmbedding                                           --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               4,916,736
│    │    └─ModuleList: 3-2                                  51,978,240
├─CistromeEmbeddingManager: 1-2                              --
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0

### Gather regulator embeddings in upregulated genes

In [16]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
up_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_gep_embs.append(emb)
up_gep_embs = torch.cat(up_gep_embs)
up_gep_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:43<00:00, 10.95s/it]


torch.Size([100, 1073, 768])

### Gather regulator embeddings in unchanged genes

In [17]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
nochange_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_gep_embs.append(emb)
nochange_gep_embs = torch.cat(nochange_gep_embs)
nochange_gep_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:43<00:00, 10.98s/it]


torch.Size([100, 1073, 768])

We keep only the regulators in the factor list loaded below, which excludes histone modifications and chromatin accessibility.

In [18]:
with open(os.path.join(base_dir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]


# Boolean mask over model_emb.list_regulator: True where the regulator is in the factor list
indices = np.isin(model_emb.list_regulator, factors)
names = np.array(model_emb.list_regulator)[indices]
up_gep_embs = up_gep_embs.mean(axis=0)[indices]
nochange_gep_embs = nochange_gep_embs.mean(axis=0)[indices]
up_gep_embs.shape, nochange_gep_embs.shape

(torch.Size([991, 768]), torch.Size([991, 768]))

### Analyze the embedding similarity between upregulated and unchanged genes

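For each regulator, we compute the cosine similarity between its mean embedding over upregulated genes and its mean embedding over unchanged genes. Regulators are then sorted by increasing similarity, so rank 1 is the regulator whose embedding differs most between the two gene sets.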


In [19]:
from sklearn.metrics.pairwise import cosine_similarity
gep_similarity = [cosine_similarity(up_gep_embs[i].reshape(1, -1), nochange_gep_embs[i].reshape(1, -1))[0, 0] for i in range(up_gep_embs.shape[0])]
gep_similarity_df = pd.DataFrame({'factors':names,'similarity':gep_similarity}).sort_values(by='similarity').reset_index(drop=True)
gep_similarity_df['rank']=gep_similarity_df.index + 1
gep_similarity_df.to_csv(f'{gep_dir}/gep_similarity_df.csv',index=False)
gep_similarity_df

Unnamed: 0,factors,similarity,rank
0,chd4,0.924975,1
1,esco2,0.935212,2
2,cbx7,0.946856,3
3,cbx6,0.946874,4
4,cbx8,0.949367,5
...,...,...,...
986,zbtb10,0.998598,987
987,zbed4,0.998614,988
988,nkx2-5,0.998627,989
989,snapc4,0.998678,990
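
The ranking logic can be illustrated without the model. The sketch below averages toy embeddings over genes, computes one cosine similarity per regulator, and sorts in ascending order. All numbers and regulator names are invented, and the helpers `cosine` and `mean_over_genes` are hypothetical stand-ins for the tensor operations used above.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def mean_over_genes(embs):
    """Average a [genes][regulators][dim] nested list over the genes axis."""
    n_genes, n_regs, dim = len(embs), len(embs[0]), len(embs[0][0])
    return [[sum(embs[g][r][d] for g in range(n_genes)) / n_genes
             for d in range(dim)] for r in range(n_regs)]

names = ["reg_a", "reg_b", "reg_c"]
up = [  # 2 genes x 3 regulators x 2 dimensions
    [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]],
    [[1.0, 0.2], [0.8, 1.0], [0.2, 1.0]],
]
nochange = [
    [[1.0, 0.1], [1.0, -0.5], [-1.0, 0.2]],
    [[1.0, 0.1], [1.0, -0.3], [-0.8, 0.2]],
]

up_mean, nochange_mean = mean_over_genes(up), mean_over_genes(nochange)
similarity = {n: cosine(u, v) for n, u, v in zip(names, up_mean, nochange_mean)}
ranked = sorted(similarity, key=similarity.get)  # lowest similarity first
print(ranked)
```

In this toy case `reg_c` ranks first because its mean embedding changes direction the most between the two gene sets.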


In [20]:
identified_factor = gep_similarity_df[gep_similarity_df['rank']<=25]['factors'].tolist()
identified_factor

['chd4',
 'esco2',
 'cbx7',
 'cbx6',
 'cbx8',
 'brd7',
 'hira',
 'myf5',
 'ring1',
 'neurog2',
 'kdm6b',
 'ubn1',
 'nr3c2',
 'rpa2',
 'sumo1',
 'yap1',
 'ptpn11',
 'myod1',
 'klf11',
 'phf2',
 'prkdc',
 'brdu',
 'ssrp1',
 'tead',
 'tead1']

In [21]:
gep_similarity_df[gep_similarity_df['factors']=='myod1']

Unnamed: 0,factors,similarity,rank
17,myod1,0.983192,18


We take the 25 regulators with the lowest embedding similarity as key regulators in the transcriptome analysis. This list includes the well-known myogenic regulator MYOD1 (rank 18).

### Combine the regulator rankings from both chromatin accessibility and transcriptome
The final top 25 key regulators are obtained by averaging each regulator's rank from the chromatin accessibility analysis and from the transcriptome analysis.

In [22]:
chrom_accessibility_path = f'{gep_dir}/../chrom_accessibility/chromatin_accessibility_similarity_df.csv'
if not os.path.exists(chrom_accessibility_path):
    raise FileNotFoundError("Please first complete the tutorial on key regulator inference during cell state transition: chromatin accessibility")
chrom_acc_similarity_df = pd.read_csv(chrom_accessibility_path)
# Keep regulators present in both analyses, then re-rank them by the mean of their two ranks
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
average_rank_df['averge_rank'] = ((average_rank_df['rank_gep']+average_rank_df['rank_chrom_acc'])/2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='averge_rank')
    
average_rank_df[average_rank_df['factors']=='myod1']

Unnamed: 0,factors,similarity_gep,rank_gep,similarity_chrom_acc,rank_chrom_acc,averge_rank
17,myod1,0.983192,18,-0.064511,1,2
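
The rank combination amounts to an inner join on regulator name followed by re-ranking on the mean rank. The dependency-free sketch below shows this on invented ranks; `combine_ranks` is a hypothetical helper. Its tie handling is a simplification: it breaks ties alphabetically, whereas pandas `rank()` by default gives tied entries their average rank, which `astype(int)` then truncates.

```python
def combine_ranks(rank_a, rank_b):
    """Inner-join two {factor: rank} mappings and re-rank by mean rank."""
    shared = sorted(set(rank_a) & set(rank_b))
    mean_rank = {f: (rank_a[f] + rank_b[f]) / 2 for f in shared}
    ordered = sorted(shared, key=lambda f: (mean_rank[f], f))
    return [(f, mean_rank[f], pos) for pos, f in enumerate(ordered, start=1)]

# Toy ranks; reg_d and reg_e each appear in only one analysis and are dropped
toy_rank_gep = {"reg_a": 1, "reg_b": 2, "reg_c": 3, "reg_d": 4}
toy_rank_acc = {"reg_a": 4, "reg_b": 1, "reg_c": 3, "reg_e": 2}

combined = combine_ranks(toy_rank_gep, toy_rank_acc)
for factor, mean_rank, final_rank in combined:
    print(factor, mean_rank, final_rank)
```

A regulator that ranks moderately in one analysis can still end up near the top if it ranks very highly in the other, as the MYOD1 row above shows.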


In [23]:
final_identified_factor = average_rank_df[average_rank_df['averge_rank']<=25]['factors'].tolist()
final_identified_factor

['myf5',
 'myod1',
 'neurog2',
 'yap1',
 'nr3c2',
 'tead1',
 'tead',
 'dux4',
 'pgbd3',
 'chd4',
 'myog',
 'hira',
 'pax3-foxo1a',
 'tbx5',
 'nr3c1',
 'snai2',
 'ss18',
 'prmt5',
 'ubn1',
 'rb1',
 'six2',
 'klf11',
 'ercc6',
 'sumo1',
 'esco2']