# Example of key regulator inference during cell state transitions: chromatin accessibility

To comprehensively infer key regulators during cell state transitions, it is crucial to integrate analyses of chromatin accessibility and the transcriptome. In this tutorial, we will demonstrate how to use ChromBERT to infer key regulators involved in a specific transdifferentiation process (fibroblast to myoblast) through chromatin accessibility analysis.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import chrombert
import pandas as pd
import numpy as np
from torchinfo import summary
import subprocess
import torch
import lightning.pytorch as pl
base_dir = os.path.expanduser("~/.cache/chrombert/data")  # default base directory for ChromBERT data


## Preprocess dataset  

In this section, we prepare raw chromatin accessibility data, including peak and BigWig files for the cell types involved in transdifferentiation. This tutorial will guide you through formatting the data for use with ChromBERT.  

To identify key regulators, we fine-tune ChromBERT to predict the log2-transformed fold change in chromatin accessibility signal during transdifferentiation, so these fold changes must be computed first. In addition, to compare regulator embeddings between upregulated and unchanged loci, we need to identify both types of loci. The preprocessing steps are:  

- Collecting peaks involved in transdifferentiation, plus TSS flank regions (10 kb on each side of every TSS) as background.  
- Overlapping these regions with the 1 kb regions used in ChromBERT.  
- Extracting chromatin accessibility signals from BigWig files for each of these 1 kb regions.  
- Calculating log2-transformed chromatin accessibility changes.  


In [2]:

chromatin_accessibility_dir = f'{base_dir}/demo/transdifferentiation/chrom_accessibility'


In [3]:
# Download the data needed for this tutorial (files that already exist are skipped).
# ENCODE distributes the peak files gzip-compressed, so they are decompressed while downloading;
# otherwise the later `cat`/`sort`/`bedtools merge` steps would operate on binary data.
os.makedirs(chromatin_accessibility_dir, exist_ok=True)
encode_url = 'https://www.encodeproject.org/files'
downloads = {
    'fibroblast_ENCFF184KAM_peak.bed': f'{encode_url}/ENCFF184KAM/@@download/ENCFF184KAM.bed.gz',
    'fibroblast_ENCFF361BTT_signal.bigwig': f'{encode_url}/ENCFF361BTT/@@download/ENCFF361BTT.bigWig',
    'myoblast_ENCFF647RNC_peak.bed': f'{encode_url}/ENCFF647RNC/@@download/ENCFF647RNC.bed.gz',
    'myoblast_ENCFF149ERN_signal.bigwig': f'{encode_url}/ENCFF149ERN/@@download/ENCFF149ERN.bigWig',
}
for filename, url in downloads.items():
    out_path = f'{chromatin_accessibility_dir}/{filename}'
    if os.path.exists(out_path):
        continue
    if url.endswith('.gz'):
        cmd = f'wget -qO- {url} | gunzip -c > {out_path}'
    else:
        cmd = f'wget {url} -O {out_path}'
    subprocess.run(cmd, shell=True, check=True)

#### Prepare merged peaks  

Here we merge the peaks for the cell types involved in transdifferentiation to generate a comprehensive region list.  
Then, we align the genomic coordinates of these regions with ChromBERT's 1 kb bins.
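The merge-and-bin step can be illustrated with a minimal pure-Python sketch. It assumes that ChromBERT's bins are consecutive, non-overlapping 1 kb windows starting at position 0, and that a region is assigned to every bin it overlaps by at least 1 bp. `merge_intervals` and `to_bins` are hypothetical helpers. ChromBERT's own `process` function may use a different overlap rule, and it also returns the `build_region_index` of each bin.

```python
def merge_intervals(intervals):
    """Merge overlapping or book-ended (chrom, start, end) intervals, similar to `bedtools merge`."""
    merged = []
    for chrom, start, end in sorted(intervals):
        if merged and merged[-1][0] == chrom and start <= merged[-1][2]:
            # Same chromosome and overlapping/touching the previous interval: extend it
            merged[-1][2] = max(merged[-1][2], end)
        else:
            merged.append([chrom, start, end])
    return [tuple(interval) for interval in merged]


def to_bins(intervals, bin_size=1000):
    """Return sorted, de-duplicated fixed-size bins that overlap each half-open interval by >= 1 bp."""
    bins = set()
    for chrom, start, end in intervals:
        first_bin = start // bin_size
        last_bin = -(-end // bin_size)  # integer ceil division; exclusive upper bound
        for b in range(first_bin, last_bin):
            bins.add((chrom, b * bin_size, (b + 1) * bin_size))
    return sorted(bins)


# Toy peaks: the first two overlap and are merged before binning
peaks = [("chr1", 180200, 180900), ("chr1", 180800, 181500), ("chr1", 191100, 191300)]
print(merge_intervals(peaks))
# [('chr1', 180200, 181500), ('chr1', 191100, 191300)]
print(to_bins(merge_intervals(peaks)))
# [('chr1', 180000, 181000), ('chr1', 181000, 182000), ('chr1', 191000, 192000)]
```

The intervals are treated as half-open, as in BED. A region ending exactly at 181000 therefore does not touch the 181000-182000 bin.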

In [4]:
cmd = f"cat {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed > {chromatin_accessibility_dir}/tmp_peak.bed"
subprocess.run(cmd, shell=True)
cmd = f"sort -k1,1 -k2,2n {chromatin_accessibility_dir}/tmp_peak.bed > {chromatin_accessibility_dir}/tmp_peak_sorted.bed"
subprocess.run(cmd, shell=True)
cmd = f"bedtools merge -i {chromatin_accessibility_dir}/tmp_peak_sorted.bed > {chromatin_accessibility_dir}/total_peak.bed"
subprocess.run(cmd, shell=True)

CompletedProcess(args='bedtools merge -i /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/tmp_peak_sorted.bed > /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/total_peak.bed', returncode=0)

In [5]:
from chrombert.scripts.chrombert_make_dataset import get_regions,process
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
total_peak_process = process(f'{chromatin_accessibility_dir}/total_peak.bed',chrom_regions,mode='region')[['chrom','start','end','build_region_index']]
len(total_peak_process),total_peak_process.head()

(396195,
   chrom   start     end  build_region_index
 0  chr1  180000  181000                  38
 1  chr1  181000  182000                  39
 2  chr1  182000  183000                  40
 3  chr1  191000  192000                  46
 4  chr1  268000  269000                  54)

#### Generate background regions  

In this tutorial, we use TSS flank regions (within 10 kb of a transcription start site) as background regions for fine-tuning and for identifying key regulators.

In [6]:
gep_df = pd.read_csv(f'{chromatin_accessibility_dir}/../transcriptome/fibroblast_to_myoblast_expression_changes.csv')
# TSS +/- 10 kb; clip at 0 so that TSSs near a chromosome start do not produce negative coordinates
gep_df_tss_10kb = pd.DataFrame({
    'chrom': gep_df['chrom'],
    'start': (gep_df['tss'] - 10000).clip(lower=0),
    'end': gep_df['tss'] + 10000,
})
gep_df_tss_10kb.to_csv(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed', sep='\t', index=False, header=None)
gep_df_tss_10kb.head()

   chrom     start       end
0  chr19  58343492  58363492
1  chr19  58337718  58357718
2  chr10  50875675  50895675
3  chr12   9106229   9126229
4  chr12   9055163   9075163


In [7]:
gep_df_tss_10kb_process = process(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',chrom_regions,mode='region').drop_duplicates(subset='build_region_index')[['chrom','start','end','build_region_index']]
len(gep_df_tss_10kb_process),gep_df_tss_10kb_process.head()

(295682,
   chrom   start     end  build_region_index
 0  chr1  815000  816000                 126
 1  chr1  816000  817000                 127
 2  chr1  817000  818000                 128
 3  chr1  818000  819000                 129
 4  chr1  819000  820000                 130)

#### Collect total regions and further process
Here we concatenate the merged peaks and the background regions to obtain the full set of regions. We then extract the chromatin accessibility signal of each region in both cell types and compute the log2-transformed fold change.
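The label used for fine-tuning is a log2 fold change with a pseudocount of 1. It is computed from signals that `bw_getSignal_bins` first divides by each BigWig's genome-wide mean. The sketch below restates that arithmetic with NumPy. The helper names are hypothetical. It checks the formula against the first region of the table produced further down (fibroblast 0.066471, myoblast 0.136553, label 0.091821).

```python
import numpy as np


def normalize_by_mean(raw_signal, genome_mean):
    """Divide raw bin signals by the genome-wide mean, so 1.0 equals the genome-wide average."""
    return np.asarray(raw_signal, dtype=float) / genome_mean


def log2_fold_change(signal_before, signal_after, pseudocount=1.0):
    """log2(pseudocount + after) - log2(pseudocount + before), i.e. log2 of the ratio of shifted signals."""
    before = np.asarray(signal_before, dtype=float)
    after = np.asarray(signal_after, dtype=float)
    return np.log2(pseudocount + after) - np.log2(pseudocount + before)


# First region of the label table: fibroblast 0.066471, myoblast 0.136553
lfc = float(log2_fold_change(0.066471, 0.136553))
print(lfc)  # should be close to the tabulated label 0.091821
```

The pseudocount keeps the label finite for bins with zero signal in one cell type. It also damps fold changes between very weak signals.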


In [8]:
total_region_processed = pd.concat([total_peak_process,gep_df_tss_10kb_process],axis=0).drop_duplicates().reset_index(drop=True)
total_region_processed.to_csv(f'{chromatin_accessibility_dir}/total_region_processed.csv',index=False)
len(total_region_processed),total_region_processed.head()

(614861,
   chrom   start     end  build_region_index
 0  chr1  180000  181000                  38
 1  chr1  181000  182000                  39
 2  chr1  182000  183000                  40
 3  chr1  191000  192000                  46
 4  chr1  268000  269000                  54)

In [9]:
# Extract the chromatin accessibility signals

import bbi  # pip install pybbi

def bw_getSignal_bins(bw, regions: pd.DataFrame, name):
    """Mean BigWig signal of each region, normalized by the genome-wide mean signal of the file."""
    with bbi.open(str(bw)) as bwf:
        # bins=1 -> one mean value per region; missing=0 fills positions without data
        mtx = bwf.stackup(regions["chrom"], regions["start"], regions["end"], bins=1, missing=0)
        mean = bwf.info["summary"]["mean"]
    mtx = mtx / mean
    return pd.DataFrame(data=mtx, columns=[f'{name}_signal'])

fibroblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig',regions=total_region_processed,name='fibroblast')
myoblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig',regions=total_region_processed,name='myoblast')



In [12]:
# Prepare the log2-transformed chromatin accessibility signal fold changes data

total_region_signal_processed = pd.concat([total_region_processed, fibroblast_signal, myoblast_signal], axis=1)
# label = log2(1 + myoblast) - log2(1 + fibroblast): log2 fold change with a pseudocount of 1
total_region_signal_processed['fold_change'] = np.log2(1 + total_region_signal_processed['myoblast_signal']) - np.log2(1 + total_region_signal_processed['fibroblast_signal'])
chrom_accessibility_df = (
    total_region_signal_processed[
        ['chrom', 'start', 'end', 'build_region_index', 'fold_change', 'fibroblast_signal', 'myoblast_signal']
    ].rename(columns={'fold_change': 'label'})
)
chrom_accessibility_df.to_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv', index=False)
chrom_accessibility_df.head()

   chrom   start     end  build_region_index      label  fibroblast_signal  myoblast_signal
0   chr1  180000  181000                  38   0.091821           0.066471         0.136553
1   chr1  181000  182000                  39   0.001996           2.543848         2.548754
2   chr1  182000  183000                  40   0.142237           0.102401         0.216626
3   chr1  191000  192000                  46  -0.678009           1.386900         0.491877
4   chr1  268000  269000                  54  -0.329329           1.324022         0.849704


#### Prepare fine-tuning data  

The dataset is split into training, test, and validation sets in an 8:1:1 ratio.  
To make fine-tuning simpler and faster, each split is then downsampled to a small random subset of 80 training, 20 test, and 20 validation regions.
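The downsampled subsets are drawn from within each split, so they should never share a region. A small hypothetical helper like the one below can confirm this. It relies only on the pandas row index, which is unique here because the region table was built with `reset_index(drop=True)`. The toy data repeats the same 8:1:1 procedure on 100 rows.

```python
import pandas as pd


def check_disjoint_splits(*splits):
    """Raise ValueError if any row index occurs in more than one split; return the total row count."""
    seen = set()
    for i, split in enumerate(splits):
        overlap = seen.intersection(split.index)
        if overlap:
            raise ValueError(f"split {i} shares {len(overlap)} rows with an earlier split")
        seen.update(split.index)
    return len(seen)


# Toy example: the same 8:1:1 procedure as in the cell below, on 100 rows
toy = pd.DataFrame({"label": range(100)})
train = toy.sample(frac=0.8, random_state=55)
test = toy.drop(train.index).sample(frac=0.5, random_state=55)
valid = toy.drop(train.index).drop(test.index)
print(len(train), len(test), len(valid), check_disjoint_splits(train, test, valid))
# 80 10 10 100
```

In the tutorial, the same check could be applied as `check_disjoint_splits(train_data_sample, test_data_sample, valid_data_sample)`.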

In [13]:
train_data = chrom_accessibility_df.sample(frac=0.8,random_state=55)
test_data = chrom_accessibility_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = chrom_accessibility_df.drop(train_data.index).drop(test_data.index)

train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=20,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)


train_data_sample.to_csv(f'{chromatin_accessibility_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{chromatin_accessibility_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{chromatin_accessibility_dir}/valid_sample.csv',index=False)
train_data_sample.head()

        chrom      start        end  build_region_index      label  fibroblast_signal  myoblast_signal
58450   chr11   16348000   16349000              300336  -0.422661           1.251264         0.679549
81335   chr12   49808000   49809000              422035   1.398294           0.950350         4.140920
346805   chr7  148855000  148856000             1856383  -0.905614           2.071367         0.639512
35488    chr1  246417000  246418000              182254  -2.684815          27.481988         3.429557
247095   chr3  132783000  132784000             1302011   3.011465           3.356764        34.132198



#### Identification of upregulated and unchanged genomic regions  

Genomic regions with a log2 fold change in chromatin accessibility greater than 2 were classified as upregulated. To define unchanged regions, we first excluded regions with zero signal in either cell type. We then selected the 40,000 remaining regions with the smallest absolute log2 fold change.
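The selection rules above can be written as a compact function. This sketch uses a hypothetical name, `select_up_and_unchanged`. It applies `Series.nsmallest` to the absolute label instead of adding a helper column and sorting. It should pick the same regions as the tutorial's cell, except possibly among regions whose |label| ties at the 40,000 cutoff, where tie-breaking may differ. The toy values are made up.

```python
import pandas as pd


def select_up_and_unchanged(df, up_threshold=2.0, n_unchanged=40000):
    """Return (upregulated, unchanged) regions from a table with label and per-cell-type signals."""
    up = df[df["label"] > up_threshold]
    # Only regions with signal in both cell types are eligible as "unchanged"
    covered = df[(df["fibroblast_signal"] > 0) & (df["myoblast_signal"] > 0)]
    # The n_unchanged smallest |label| values; no helper column, so no SettingWithCopyWarning
    unchanged = covered.loc[covered["label"].abs().nsmallest(n_unchanged).index]
    return up, unchanged


toy = pd.DataFrame({
    "label": [3.0, -0.1, 0.05, 2.5, -1.0, 0.0],
    "fibroblast_signal": [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],  # last row has no fibroblast coverage
    "myoblast_signal": [8.0, 0.9, 1.05, 5.0, 0.5, 0.2],
})
up, unchanged = select_up_and_unchanged(toy, n_unchanged=2)
print(sorted(up.index), sorted(unchanged.index))
# [0, 3] [1, 2]
```

Row 5 has the smallest |label| in the toy table. It is still excluded because its fibroblast signal is zero.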


In [14]:
chrom_accessibility_df = pd.read_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv')
# Upregulated: log2 fold change > 2
up_data = chrom_accessibility_df[chrom_accessibility_df['label'] > 2]

# Unchanged: among regions with signal in both cell types, the 40,000 with the smallest |log2 fold change|
covered_region = chrom_accessibility_df[
    (chrom_accessibility_df['fibroblast_signal'] > 0) & (chrom_accessibility_df['myoblast_signal'] > 0)
].copy()  # .copy() avoids pandas' SettingWithCopyWarning on the next line
covered_region['label_abs'] = np.abs(covered_region['label'])
nochange_data = covered_region.sort_values('label_abs').reset_index(drop=True).iloc[0:40000]

# Downsample to 100 regions of each type to save time
up_data_sample = up_data.sample(n=100, random_state=55)
nochange_data_sample = nochange_data.sample(n=100, random_state=55)

up_data_sample.to_csv(f'{chromatin_accessibility_dir}/up_data_sample.csv', index=False)
nochange_data_sample.to_csv(f'{chromatin_accessibility_dir}/nochange_data_sample.csv', index=False)


## Fine-tune  

This section shows how to fine-tune ChromBERT to predict genome-wide changes in chromatin accessibility.  
Fine-tuning for chromatin accessibility follows the general process:  
- **Dataset configuration:** Use the `general` preset dataset configuration.  
- **Model instantiation:** Use the `general` preset model configuration.  


### Configure dataset and data module

In [15]:
# Use the `general` preset dataset config
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = None, batch_size = 4, num_workers = 4)
dataset_config

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'GeneralDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'pin_memory': True, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'prompt_regulator_cache_pin_memory': False, 'prompt_regulator_cache_limit': 3, 'fasta_file': None, 'flank_window': 0})

In [16]:
# We use the `LitChromBERTFTDataModule` to load the data and facilitate the fine-tuning process.
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config, 
    train_params = {'supervised_file': f'{chromatin_accessibility_dir}/train_sample.csv'}, 
    val_params = {'supervised_file':f'{chromatin_accessibility_dir}/valid_sample.csv'}, 
    test_params = {'supervised_file':f'{chromatin_accessibility_dir}/test_sample.csv'}
)
data_module.setup()

### Configure and instantiate the model

In [17]:
# we use `general` preset model config 
model_config = chrombert.get_preset_model_config("general")
model_config

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt


ChromBERTFTConfig:
{
    "genome": "hg38",
    "task": "general",
    "dim_output": 1,
    "mtx_mask": "/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": false,
    "prompt_kind": "cistrome",
    "prompt_dim_external": 512,
    "dnabert2_ckpt": null
}

In [18]:
model = model_config.init_model()
model.freeze_pretrain(2)  # freeze the embeddings and the first 6 of the 8 transformer blocks; the last 2 stay trainable (see summary below)
summary(model)

use organisim hg38; max sequence length is 6391




Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         (7,680)
│    │    └─PositionalEmbedding: 3-2                    (4,909,056)
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-5                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-6                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-7                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-8                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-9                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock

### Configure training parameters and fine-tune  

The model is fine-tuned using PyTorch Lightning with a straightforward parameter configuration. To save time, fine-tuning is performed on a smaller dataset.  

*Note:* Due to the stochastic nature of the tuning process, results may vary. For improved performance, consider increasing the number of epochs and expanding the dataset size.  
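The settings used here are 80 training regions, `batch_size = 4`, `accumulate_grad_batches=2`, and `max_epochs=2`. With gradient accumulation, each optimizer update therefore sees 4 × 2 = 8 regions. The hypothetical helper below does the step arithmetic. It assumes that the last partial batch is kept and that Lightning also steps on a final, incomplete accumulation window; neither assumption matters for these numbers, which divide evenly.

```python
def optimizer_steps(n_train, batch_size, accumulate_grad_batches, max_epochs):
    """Optimizer updates over training, assuming no drop_last and a step on any trailing partial window."""
    batches_per_epoch = -(-n_train // batch_size)                       # ceil division
    steps_per_epoch = -(-batches_per_epoch // accumulate_grad_batches)  # ceil division
    return steps_per_epoch * max_epochs


effective_batch = 4 * 2  # batch_size * accumulate_grad_batches
print(effective_batch, optimizer_steps(n_train=80, batch_size=4, accumulate_grad_batches=2, max_epochs=2))
# 8 20
```

With only about 20 optimizer updates, poor test performance from this short run is expected. Increasing the dataset size or `max_epochs` raises the number of updates proportionally.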


In [19]:
train_config = chrombert.finetune.TrainConfig(kind='regression',        
                                              loss='rmse',
                                              max_epochs=2,
                                              accumulate_grad_batches=2,
                                              val_check_interval=2,
                                              limit_val_batches=10,
                                              tag='chrom_accessibility')
train_config
train_module = train_config.init_pl_module(model)



In [20]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1, 
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu", 
    accumulate_grad_batches= train_config.accumulate_grad_batches, 
    fast_dev_run=False, 
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,   
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='chrom_accessibility'),
    )
trainer.fit(train_module,data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/chrom_accessibility
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M | train
---------------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.095   Total estimated mode

/home/chenqianqian/.conda/envs/flash23_torch20/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.

`Trainer.fit` stopped: `max_epochs=2` reached.



## Evaluate the fine-tuned model  

ChromBERT has been successfully fine-tuned! You can evaluate the fine-tuned model directly via `train_module.model`. Note, however, that because of the flash-attention settings, calling `model.eval()` does not disable dropout, which may introduce some randomness into the output.  

To obtain consistent results, we therefore recommend saving the fine-tuned checkpoint, reloading it into a model built from the original configuration with `dropout=0`, and evaluating that model.  

In addition, we use only downsampled test data for evaluation in this tutorial to save time.

### Evaluate the model fine-tuned with limited data

We first load the fine-tuned model and evaluate it on the downsampled test data here.


In [21]:
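import glob
# Use the most recently written checkpoint in case several training runs (version_*) exist
ckpt_files = glob.glob('./lightning_logs/chrom_accessibility/version_*/checkpoints/*.ckpt')
chrom_accessibility_ft_ckpt = os.path.abspath(max(ckpt_files, key=os.path.getmtime))
print(chrom_accessibility_ft_ckpt)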
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391




Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public_2/examples/tutorials/lightning_logs/chrom_accessibility/version_0/checkpoints/epoch=1-step=12.ckpt




Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


In [22]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

100%|██████████| 5/5 [00:01<00:00,  2.74it/s]

{'pearsonr': tensor(0.1227), 'spearmanr': tensor(0.2571), 'mse': tensor(0.7629), 'mae': tensor(0.5203), 'r2': tensor(-0.0613)}
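As a cross-check of the torchmetrics scores above, the same five metrics can be computed with plain NumPy and pandas. In this sketch, `regression_scores` is a hypothetical helper. Spearman correlation is computed as the Pearson correlation of average-tie ranks, and R² as 1 - SS_res / SS_tot. The results should closely match torchmetrics, up to floating-point differences.

```python
import numpy as np
import pandas as pd


def regression_scores(preds, labels):
    """Pearson, Spearman, MSE, MAE and R^2 for 1-D predictions against labels."""
    preds = np.asarray(preds, dtype=float).ravel()
    labels = np.asarray(labels, dtype=float).ravel()
    residual = labels - preds
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum((labels - labels.mean()) ** 2)
    pred_ranks = pd.Series(preds).rank().to_numpy()   # average ranks for ties
    label_ranks = pd.Series(labels).rank().to_numpy()
    return {
        "pearsonr": float(np.corrcoef(preds, labels)[0, 1]),
        "spearmanr": float(np.corrcoef(pred_ranks, label_ranks)[0, 1]),
        "mse": float(np.mean(residual ** 2)),
        "mae": float(np.mean(np.abs(residual))),
        "r2": float(1.0 - ss_res / ss_tot),
    }


print(regression_scores([2, 2, 3, 5], [1, 2, 3, 4]))
```

Applied to the tensors above, this would be `regression_scores(predicts.float().numpy(), labels.float().numpy())`. A negative R², as seen here, means the predictions fit worse than always predicting the mean label.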





### Performance of the model fine-tuned on the full dataset  

The model fine-tuned on the small sample performs poorly on the test set (Pearson r ≈ 0.12, R² ≈ -0.06). We therefore provide a checkpoint fine-tuned on the entire dataset and evaluate it here on the same downsampled test set.  



In [28]:
chrom_accessibility_ft_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391




Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters




In [29]:
from tqdm import tqdm
import torchmetrics as tm
dl = data_module.test_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)


100%|██████████| 5/5 [00:01<00:00,  2.80it/s]

{'pearsonr': tensor(0.9576), 'spearmanr': tensor(0.8451), 'mse': tensor(0.0629), 'mae': tensor(0.2033), 'r2': tensor(0.9125)}





## Infer key regulators  

Using the fine-tuned model, we compare two embeddings for each regulator: its embedding averaged over upregulated regions and its embedding averaged over unchanged regions. We hypothesize that a lower cosine similarity between these two mean embeddings indicates a larger functional shift, pointing to a potential key regulator of the cell state transition.  

To save time, we perform the analysis on downsampled sets of 100 upregulated and 100 unchanged regions.  

*Note:* Only the regulators listed in `hg38_6k_factors_list.txt` are considered; histone modifications and chromatin accessibility tracks are excluded.  
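The embedding manager returns one 768-dimensional vector per region and regulator, with shape `(n_regions, n_regulators, 768)` below. Each regulator's embeddings are averaged over regions, and only the listed factors are kept. In this NumPy sketch of that pooling step, the function name is hypothetical and small toy arrays replace the real embeddings.

```python
import numpy as np


def mean_regulator_embeddings(region_embs, regulator_names, keep):
    """Average (n_regions, n_regulators, dim) embeddings over regions and keep regulators in `keep`."""
    names = np.asarray(regulator_names)
    mask = np.isin(names, list(keep))  # boolean mask over the regulator axis
    return names[mask], region_embs.mean(axis=0)[mask]


# Toy data: 3 regions, 4 "regulators", 2-dimensional embeddings
embs = np.arange(24, dtype=float).reshape(3, 4, 2)
names, means = mean_regulator_embeddings(embs, ["myod1", "h3k27ac", "tead1", "dnase"], {"myod1", "tead1"})
print(names, means.shape)
# ['myod1' 'tead1'] (2, 2)
```

Averaging over regions before comparing the two sets gives each regulator a single "context" vector per condition. This makes the later similarity comparison one vector pair per regulator.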


In [30]:
# Load the fine-tuned model
model_tuned = chrombert.get_preset_model_config(
    "general", 
    dropout = 0,
    finetune_ckpt = f'{chromatin_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length is 6391




Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt




Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


In [31]:
# Get the embedding manager
model_emb = model_tuned.get_embedding_manager().cuda()


### Gather regulator embeddings in upregulated genomic regions

In [32]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chromatin_accessibility_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
up_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_chrom_acc_embs.append(emb)
up_chrom_acc_embs = torch.cat(up_chrom_acc_embs,dim=0)
up_chrom_acc_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


torch.Size([100, 1073, 768])

### Gather regulator embeddings in unchanged genomic regions

In [33]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chromatin_accessibility_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
nochange_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_chrom_acc_embs.append(emb)
nochange_chrom_acc_embs = torch.cat(nochange_chrom_acc_embs,dim=0)
nochange_chrom_acc_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:05<00:00,  1.39s/it]


torch.Size([100, 1073, 768])

In [34]:
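with open(os.path.join(base_dir, "config", "hg38_6k_factors_list.txt"), "r") as fh:
    factors = fh.read().strip().split("\n")
factors = [name.strip().lower() for name in factors]

# Keep only the listed regulators (drops histone modifications and chromatin accessibility);
# np.isin replaces np.in1d, which is deprecated in recent NumPy releases
indices = np.isin(model_emb.list_regulator, factors)
names = np.array(model_emb.list_regulator)[indices]
# Average over regions, then select regulators: (n_regions, n_regulators, 768) -> (n_selected, 768)
up_chrom_acc_embs = up_chrom_acc_embs.mean(axis=0)[indices]
nochange_chrom_acc_embs = nochange_chrom_acc_embs.mean(axis=0)[indices]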

In [35]:
up_chrom_acc_embs.shape, nochange_chrom_acc_embs.shape

(torch.Size([991, 768]), torch.Size([991, 768]))

### Analyze the embedding similarity between upregulated and unchanged genomic regions

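For each regulator, we compute the cosine similarity between its mean embedding in upregulated regions and its mean embedding in unchanged regions. We then rank regulators by ascending similarity, so rank 1 has the lowest similarity, i.e. the largest shift.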
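The loop in the next cell calls scikit-learn's `cosine_similarity` once per regulator. The same row-wise cosine similarity and ranking can be written in a vectorized way with NumPy and pandas. In this sketch, `rank_by_cosine_shift` is a hypothetical helper, and the small `eps` guard against zero-norm vectors is an added assumption.

```python
import numpy as np
import pandas as pd


def rank_by_cosine_shift(names, up_embs, nochange_embs, eps=1e-12):
    """Rank regulators by cosine similarity of paired rows; rank 1 = lowest similarity."""
    up = np.asarray(up_embs, dtype=float)
    nc = np.asarray(nochange_embs, dtype=float)
    # Row-wise cosine similarity: dot(u_i, v_i) / (|u_i| * |v_i|)
    norms = np.linalg.norm(up, axis=1) * np.linalg.norm(nc, axis=1)
    sim = (up * nc).sum(axis=1) / np.maximum(norms, eps)
    df = pd.DataFrame({"factors": names, "similarity": sim})
    df = df.sort_values(by="similarity").reset_index(drop=True)
    df["rank"] = df.index + 1
    return df


# Toy vectors: "a" unchanged, "b" orthogonal, "c" reversed
print(rank_by_cosine_shift(["a", "b", "c"], [[1, 0], [1, 0], [1, 1]], [[1, 0], [0, 1], [-1, -1]]))
```

With the tutorial's tensors, this would be called as `rank_by_cosine_shift(names, up_chrom_acc_embs.float().numpy(), nochange_chrom_acc_embs.float().numpy())`.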


In [36]:
from sklearn.metrics.pairwise import cosine_similarity
chrom_acc_similarity = [cosine_similarity(up_chrom_acc_embs[i].reshape(1, -1), nochange_chrom_acc_embs[i].reshape(1, -1))[0, 0] for i in range(up_chrom_acc_embs.shape[0])]
chrom_acc_similarity_df = pd.DataFrame({'factors':names,'similarity':chrom_acc_similarity}).sort_values(by='similarity').reset_index(drop=True)
chrom_acc_similarity_df['rank']=chrom_acc_similarity_df.index + 1
chrom_acc_similarity_df.to_csv(f'{chromatin_accessibility_dir}/chromatin_accessibility_similarity_df.csv',index=False)
chrom_acc_similarity_df

         factors  similarity  rank
0          myod1   -0.064511     1
1           myf5    0.080093     2
2           myog    0.101807     3
3    pax3-foxo1a    0.107334     4
4          tead1    0.214081     5
..           ...         ...   ...
986       znf250    0.984589   987
987        hoxa1    0.984709   988
988       zbtb10    0.985032   989
989       znf706    0.985040   990


In [37]:
chrom_acc_similarity_df[chrom_acc_similarity_df['factors']=='myod1']

   factors  similarity  rank
0    myod1   -0.064511     1


Based on chromatin accessibility alone, MYOD1, a well-known myogenic regulator, ranks first among the 991 regulators.

### Combine the regulator rankings from both chromatin accessibility and transcriptome
The final top 25 key regulators are obtained by averaging each regulator's ranks from the chromatin accessibility and transcriptome analyses and then re-ranking these averages.
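The rank-averaging step can be sketched on made-up ranks. The cell below uses pandas' default `rank()`, which assigns tied averages a fractional rank that `astype(int)` then truncates. This sketch instead uses `rank(method="min")`, so tied regulators share the better integer rank. That is a deliberate choice, not the tutorial's exact behavior. `combine_ranks` is a hypothetical helper.

```python
import pandas as pd


def combine_ranks(rank_a, rank_b, suffixes=("_gep", "_chrom_acc")):
    """Inner-join two rank tables on `factors` and re-rank by the mean of their `rank` columns."""
    merged = pd.merge(rank_a, rank_b, on="factors", how="inner", suffixes=suffixes)
    mean_rank = (merged[f"rank{suffixes[0]}"] + merged[f"rank{suffixes[1]}"]) / 2
    merged["average_rank"] = mean_rank.rank(method="min").astype(int)
    return merged.sort_values("average_rank").reset_index(drop=True)


# Made-up ranks; "znf706" is only in one table and is dropped by the inner join
gep = pd.DataFrame({"factors": ["myod1", "myf5", "tead1"], "rank": [18, 1, 3]})
acc = pd.DataFrame({"factors": ["myod1", "myf5", "tead1", "znf706"], "rank": [1, 2, 5, 4]})
print(combine_ranks(gep, acc))
```

Averaging ranks rewards regulators that score well in both modalities. A factor ranked first in one analysis but low in the other can still fall behind consistently high-ranked factors.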

In [38]:
gep_similarity_path = f'{chromatin_accessibility_dir}/../transcriptome/gep_similarity_df.csv'
if not os.path.exists(gep_similarity_path):
    raise FileNotFoundError("Please first run the tutorial on key regulator inference during cell state transitions: transcriptome")
gep_similarity_df = pd.read_csv(gep_similarity_path)
# Inner join on regulator name; overlapping columns get the suffixes _gep / _chrom_acc
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
# Mean of the two ranks, converted back into an integer ranking
average_rank_df['average_rank'] = ((average_rank_df['rank_gep'] + average_rank_df['rank_chrom_acc']) / 2).rank().astype(int)
average_rank_df = average_rank_df.sort_values(by='average_rank')

average_rank_df[average_rank_df['factors'] == 'myod1']

    factors  similarity_gep  rank_gep  similarity_chrom_acc  rank_chrom_acc  average_rank
17    myod1        0.983192        18             -0.064511               1             2


In [39]:
final_identified_factors = average_rank_df[average_rank_df['average_rank'] <= 25]['factors'].tolist()
final_identified_factors

['myf5',
 'myod1',
 'neurog2',
 'yap1',
 'nr3c2',
 'tead1',
 'tead',
 'dux4',
 'pgbd3',
 'chd4',
 'myog',
 'hira',
 'pax3-foxo1a',
 'tbx5',
 'nr3c1',
 'snai2',
 'ss18',
 'prmt5',
 'ubn1',
 'rb1',
 'six2',
 'klf11',
 'ercc6',
 'sumo1',
 'esco2']