# Tutorial for uncovers cellular dynamics using ChromBERT
In this tutorial, we will demonstrate how to use ChromBERT to uncover cellular dynamics during a specific transdifferentiation process(Fibroblast to myoblast). The fine-tuned model was trained separately on chromatin accessibility data and expression data. To follow this tutorial, you need to have the checkpoint (ckpt) files downloaded (see the README for details).


## Preprocessing dataset
You need to prepare two types of data: transcriptome data (including TSS and TPM for gene) and chromatin accessibility data (including peak and BigWig) for the cell type involved in transdifferentiation. This tutorial will show you how to transform the data into the format required by ChromBERT.

In [1]:
import chrombert
import pandas as pd
import numpy as np
import os
from torchinfo import summary
import subprocess
base_dir = '/home/chenqianqian/.cache/chrombert/data' ### to_path_chrombert/data

### Preprocessing transcriptome dataset
To fine-tune on transcriptome data for predicting log1p-transformed fold change during transdifferentiation, you need to prepare the log1p-transformed fold change data. Additionally, you need to obtain the ChromBERT 1kb region bins corresponding to the TSS of the relevant genes.

In [2]:
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'
fibroblast_exp = pd.read_csv(f'{gep_dir}/fibroblast_expression.csv')
myoblast_exp = pd.read_csv(f'{gep_dir}/myoblast_expression.csv')
myoblast_exp.head(), fibroblast_exp.head()

(   chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  22.236894
 1  chr19  58347718  ENSG00000268895   9.317134
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.993828
 4  chr12   9065163  ENSG00000245105   0.124228,
    chrom       tss          gene_id        tpm
 0  chr19  58353492  ENSG00000121410  12.774133
 1  chr19  58347718  ENSG00000268895   2.939181
 2  chr10  50885675  ENSG00000148584   0.000000
 3  chr12   9116229  ENSG00000175899   0.226091
 4  chr12   9065163  ENSG00000245105   0.226091)

#### Prepare the log1p-transformed fold change data

In [3]:
from chrombert.scripts.chrombert_make_dataset import get_regions
merge_exp = pd.merge(fibroblast_exp,myoblast_exp,left_on=['chrom','tss','gene_id'],right_on=['chrom','tss','gene_id'],suffixes=['_fibroblast','_myoblast'])
merge_exp['fold_change']= np.log1p(merge_exp['tpm_myoblast']) - np.log1p(merge_exp['tpm_fibroblast'])
merge_exp['start'] = merge_exp['tss']//1000 * 1000
merge_exp['end'] = (merge_exp['tss']//1000 + 1) * 1000
foldchange_exp = merge_exp [['chrom','start','end','tss','gene_id','fold_change']]

chrom_regions = get_regions('~/.cache/chrombert/data',genome='hg38',high_resolution=False) # 1kb
chrom_regions
chrom_regions_df = pd.read_csv(chrom_regions,sep='\t',names=['chrom','start','end','build_region_index'])
chrom_regions_df
merge_region = pd.merge(foldchange_exp,chrom_regions_df,left_on=['chrom','start','end'],right_on=['chrom','start','end'],how='inner')[['chrom','start','end','build_region_index','fold_change','tss','gene_id']]
gep_df = merge_region.rename(columns={'fold_change':'label'})
gep_df.to_csv(f'{gep_dir}/fibroblast_to_myoblast_expression_changes.csv',index=False)
gep_df.head()

Unnamed: 0,chrom,start,end,build_region_index,label,tss,gene_id
0,chr19,58353000,58354000,917950,0.522949,58353492,ENSG00000121410
1,chr19,58347000,58348000,917944,0.962833,58347718,ENSG00000268895
2,chr10,50885000,50886000,221904,0.0,50885675,ENSG00000148584
3,chr12,9116000,9117000,393001,0.486225,9116229,ENSG00000175899
4,chr12,9065000,9066000,392961,-0.086734,9065163,ENSG00000245105


#### Preprocessing the fine-tuning data of transcriptome 
splitting it into training, testing, and validation sets with an 8:1:1 ratio and Downsample the data to test the fine-tuning process

In [4]:
train_data = gep_df.sample(frac=0.8,random_state=55)
test_data = gep_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = gep_df.drop(train_data.index).drop(test_data.index)
train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=20,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)
train_data_sample.to_csv(f'{gep_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{gep_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{gep_dir}/valid_sample.csv',index=False)
train_data_sample.head()

Unnamed: 0,chrom,start,end,build_region_index,label,tss,gene_id
4496,chr6,138795000,138796000,1719247,0.0,138795911,ENSG00000203734
13237,chr8,96261000,96262000,1933979,-0.351699,96261902,ENSG00000156471
17079,chr22,29555000,29556000,1186796,-0.565257,29555216,ENSG00000100296
4118,chr8,1622000,1623000,1866390,0.0,1622417,ENSG00000253267
13482,chr7,66682000,66683000,1792879,0.402664,66682164,ENSG00000154710


#### Prepare the upregulated genes and unchanged genes for downstream analysis.

In [5]:
up_data = gep_df[gep_df['label']>1]
nochange_data = gep_df[(gep_df['label']>-0.5) & (gep_df['label']<0.5)]

up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)


up_data_sample.to_csv(f'{gep_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{gep_dir}/nochange_data_sample.csv',index=False)

### Preprocessing chromatin accessibility dataset

Consider the total peak regions and gene promoter regions (extending 10 kb from the transcription start site) as background.Overlap these regions with 1 kb regions used in ChromBERT.
Extract the chromatin accessibility signals from BigWig files for each of these 1 kb regions. Calculate log2-transformed chromatin accessibility changes. 

In [6]:

chromatin_accessibility_dir = f'{base_dir}/demo/transdifferentiation/chrom_accessibility'


In [7]:
chromatin_accessibility_dir

'/home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility'

In [8]:
if not os.path.exists(f'{chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed'):
     cmd = f'wget https://www.encodeproject.org/files/ENCFF184KAM/@@download/ENCFF184KAM.bed.gz -O {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed'
     subprocess.run(cmd, shell=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig'):
     cmd = f'wget https://www.encodeproject.org/files/ENCFF361BTT/@@download/ENCFF361BTT.bigWig -O {chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig'
     subprocess.run(cmd, shell=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed'):
     cmd = f'wget https://www.encodeproject.org/files/ENCFF647RNC/@@download/ENCFF647RNC.bed.gz -O {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed'
     subprocess.run(cmd, shell=True)
if not os.path.exists(f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig'):
     cmd = f'wget https://www.encodeproject.org/files/ENCFF149ERN/@@download/ENCFF149ERN.bigWig -O {chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig'
     subprocess.run(cmd, shell=True)     

#### Merge peak during transdifferentiation
merge peak for the cell type involved in transdifferentiation to generate total peak file and Overlap total peak file with chromBERT 1kb bins

In [9]:
cmd = f"cat {chromatin_accessibility_dir}/fibroblast_ENCFF184KAM_peak.bed {chromatin_accessibility_dir}/myoblast_ENCFF647RNC_peak.bed > {chromatin_accessibility_dir}/tmp_peak.bed"
subprocess.run(cmd, shell=True)
cmd = f"bedtools merge -i {chromatin_accessibility_dir}/tmp_peak.bed > {chromatin_accessibility_dir}/total_peak.bed"
subprocess.run(cmd, shell=True)

Error: Sorted input specified, but the file /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/tmp_peak.bed has the following out of order record
chr1	180791	180871	.	0	.	0.0397844	-1	-1	75


CompletedProcess(args='bedtools merge -i /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/tmp_peak.bed > /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/total_peak.bed', returncode=1)

In [10]:
from chrombert.scripts.chrombert_make_dataset import get_regions,process
chrom_regions = get_regions(base_dir,genome='hg38',high_resolution=False) # 1kb
total_peak_process = process(f'{chromatin_accessibility_dir}/total_peak.bed',chrom_regions,mode='region')[['chrom','start','end','build_region_index']]
len(total_peak_process),total_peak_process.head()

(249604,
   chrom   start     end  build_region_index
 0  chr1  181000  182000                  39
 1  chr1  191000  192000                  46
 2  chr1  268000  269000                  54
 3  chr1  729000  730000                  93
 4  chr1  778000  779000                 102)

#### Generate background regions
To add some background regions, you need to include regions near the promoter TSS (within 10kb) as background samples and overlap these background regions with ChromBERT 1kb bins.

In [11]:
gep_df = pd.read_csv(f'{gep_dir}/fibroblast_to_myoblast_expression_changes.csv')
gep_df_tss_10kb = pd.DataFrame({'chrom':gep_df['chrom'],'start':gep_df['tss']-10000,'end':gep_df['tss']+10000})
gep_df_tss_10kb
gep_df_tss_10kb.to_csv(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',sep='\t',index=False,header=None)
gep_df_tss_10kb.head()

Unnamed: 0,chrom,start,end
0,chr19,58343492,58363492
1,chr19,58337718,58357718
2,chr10,50875675,50895675
3,chr12,9106229,9126229
4,chr12,9055163,9075163


In [12]:
gep_df_tss_10kb_process = process(f'{chromatin_accessibility_dir}/gep_df_tss_10kb.bed',chrom_regions,mode='region').drop_duplicates(subset='build_region_index')[['chrom','start','end','build_region_index']]
len(gep_df_tss_10kb_process),gep_df_tss_10kb_process.head()

(295682,
   chrom   start     end  build_region_index
 0  chr1  815000  816000                 126
 1  chr1  816000  817000                 127
 2  chr1  817000  818000                 128
 3  chr1  818000  819000                 129
 4  chr1  819000  820000                 130)

#### Merge peak and background region to generate total region
Generate the total region by concatenating the total peak and negative regions.

In [13]:
total_region_processed = pd.concat([total_peak_process,gep_df_tss_10kb_process],axis=0).drop_duplicates().reset_index(drop=True)
total_region_processed.to_csv(f'{chromatin_accessibility_dir}/total_region_processed.csv',index=False)
len(total_region_processed),total_region_processed.head()

(491615,
   chrom   start     end  build_region_index
 0  chr1  181000  182000                  39
 1  chr1  191000  192000                  46
 2  chr1  268000  269000                  54
 3  chr1  729000  730000                  93
 4  chr1  778000  779000                 102)

#### Extract the chromatin accessibility signals

In [14]:
import bbi  # pip install pybbi
def bw_getSignal_bins(
    bw, regions:pd.DataFrame,name
    ):
    regions = regions.copy()
    with bbi.open(str(bw)) as bwf:
        mtx = bwf.stackup(regions["chrom"],regions["start"],regions["end"], bins=1, missing=0)
        mean= bwf.info["summary"]["mean"]
        mtx = mtx/mean
    df_signal = pd.DataFrame(data = mtx, columns = [f'{name}_signal'])
    return df_signal

In [15]:
fibroblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/fibroblast_ENCFF361BTT_signal.bigwig',regions=total_region_processed,name='fibroblast')


In [16]:
myoblast_signal = bw_getSignal_bins(bw=f'{chromatin_accessibility_dir}/myoblast_ENCFF149ERN_signal.bigwig',regions=total_region_processed,name='myoblast')

#### Calculate the log2-transformed fold change

In [17]:
total_region_signal_processed = pd.concat([total_region_processed,fibroblast_signal,myoblast_signal],axis=1)
total_region_signal_processed['fold_change'] = np.log2(1+total_region_signal_processed['myoblast_signal']) - np.log2(1+total_region_signal_processed['fibroblast_signal'])
total_region_signal_processed
chrom_accessibility_df = total_region_signal_processed[['chrom','start','end','build_region_index','fold_change','fibroblast_signal','myoblast_signal']].rename(columns={'fold_change':'label'})
chrom_accessibility_df.to_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv',index=False)
chrom_accessibility_df.head()

Unnamed: 0,chrom,start,end,build_region_index,label,fibroblast_signal,myoblast_signal
0,chr1,181000,182000,39,0.001996,2.543848,2.548754
1,chr1,191000,192000,46,-0.678009,1.3869,0.491877
2,chr1,268000,269000,54,-0.329329,1.324022,0.849704
3,chr1,729000,730000,93,-0.49194,0.470683,0.045756
4,chr1,778000,779000,102,0.098767,37.696015,40.437941


#### Preprocessing the fine-tuning data of chromatin accessibility 
Splitting it into training, testing, and validation sets with an 8:1:1 ratio and Downsample the data to test the fine-tuning process


In [18]:
train_data = chrom_accessibility_df.sample(frac=0.8,random_state=55)
test_data = chrom_accessibility_df.drop(train_data.index).sample(frac=0.5,random_state=55)
valid_data = chrom_accessibility_df.drop(train_data.index).drop(test_data.index)

train_data_sample = train_data.sample(n=80,random_state=55)
test_data_sample = test_data.sample(n=20,random_state=55)
valid_data_sample = valid_data.sample(n=20,random_state=55)


train_data_sample.to_csv(f'{chromatin_accessibility_dir}/train_sample.csv',index=False)
test_data_sample.to_csv(f'{chromatin_accessibility_dir}/test_sample.csv',index=False)
valid_data_sample.to_csv(f'{chromatin_accessibility_dir}/valid_sample.csv',index=False)
train_data_sample.head()

Unnamed: 0,chrom,start,end,build_region_index,label,fibroblast_signal,myoblast_signal
267005,chr1,160963000,160964000,111173,0.25308,0.235342,0.472217
458923,chr7,73835000,73836000,1798466,0.044447,0.034134,0.066489
50428,chr12,41932000,41933000,415524,-0.245291,1.433608,1.053103
250859,chr1,8210000,8211000,6478,0.191804,0.302711,0.487945
370974,chr19,50858000,50859000,912107,0.274903,0.195819,0.446836


#### Prepare the upregulated regions and unchanged regions for downstream analysis. 

Regions showing a log2-transformed chrom accessibility difference greater than 2 were marked as having increased chromatin accessibility, Uncovered regions were excluded, with the top 40,000 regions showing the absolute fold changes selected as those with unchanged accessibility.

In [19]:
chrom_accessibility_df = pd.read_csv(f'{chromatin_accessibility_dir}/fibroblast_to_myoblast_chrom_accessibility_changes.csv')
up_data = chrom_accessibility_df[chrom_accessibility_df['label']>2]

covered_region = chrom_accessibility_df[(chrom_accessibility_df['fibroblast_signal']>0) & (chrom_accessibility_df['myoblast_signal']>0)]
covered_region['label_abs'] = np.abs(covered_region['label'])
nochange_data = covered_region.sort_values('label_abs').reset_index(drop=True).iloc[0:40000]


up_data_sample = up_data.sample(n=100,random_state=55)
nochange_data_sample = nochange_data.sample(n=100,random_state=55)


up_data_sample.to_csv(f'{chromatin_accessibility_dir}/up_data_sample.csv',index=False)
nochange_data_sample.to_csv(f'{chromatin_accessibility_dir}/nochange_data_sample.csv',index=False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  covered_region['label_abs'] = np.abs(covered_region['label'])


## Fine-tune

In this section, we provide a tutorial for fine-tuning ChromBERTs to predict genome-wide changes in transcriptome and chromatin accessibility. 
The fine-tuning process for chromatin accessibility is general. However, the objectives for transcriptome changes differed due to the potential uncertainty in the influence of regions adjacent to transcription start sites (TSSs) on gene expression,  with a few key modifications:
- Dataset config Preparation: We use the `multi_flank_window` preset dataset config and set four `flank_window` parameter
- Model Instantiation: We use the `gep` preset model config and set four `gep_flank_window` parameter
Let's get started!

In [1]:
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import torch
import chrombert
import pandas as pd
from torchinfo import summary
import lightning.pytorch as pl 


In [2]:
base_dir = '/home/chenqianqian/.cache/chrombert/data'

### Fine-tuning ChromBERTs to predict genome-wide changes in transcriptome 

#### Set dataset config and data_module

We use the `multi_flank_window` preset dataset config and set four `flank_window` parameter

In [3]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = None, batch_size = 4, num_workers = 4,flank_window=4)
dataset_config

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'MultiFlankwindowDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'fasta_file': None, 'flank_window': 4})

In [4]:
gep_dir = f'{base_dir}/demo/transdifferentiation/transcriptome'

In [5]:
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config, 
    train_params = {'supervised_file': f'{gep_dir}/train_sample.csv'}, 
    val_params = {'supervised_file':f'{gep_dir}/valid_sample.csv'}, 
    test_params = {'supervised_file':f'{gep_dir}/test_sample.csv'}
)
data_module.setup()

#### Set model config and model loading
We use the `gep` preset model config and set four `gep_flank_window` parameter

In [6]:
model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4)
model_config

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt


ChromBERTFTConfig(genome='hg38', task='gep', dim_output=1, mtx_mask='/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv', dropout=0.1, pretrain_ckpt='/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt', finetune_ckpt=None, ignore=False, ignore_index=(None, None), gep_flank_window=4, gep_parallel_embedding=False, gep_gradient_checkpoint=False, gep_zero_inflation=True, prompt_kind='cistrome', prompt_dim_external=512, dnabert2_ckpt=None)

In [7]:
model = model_config.init_model()
model.freeze_pretrain(2) ### freeze chrombert 6 transformer blocks
summary(model)

use organisim hg38; max sequence length including cls is 6392


Layer (type:depth-idx)                                       Param #
ChromBERTGEP                                                 --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               (4,916,736)
│    │    └─ModuleList: 3-2                                  51,978,240
├─GepHeader: 1-2                                             --
│    └─CistromeEmbeddingManager: 2-2                         --
│    └─Conv2d: 2-3                                           769
│    └─ReLU: 2-4                                             --
│    └─ResidualBlock: 2-5                                    --
│    │    └─Linear: 3-3                                      1,090,560
│    │    └─Linear: 3-4                                      1,049,600
│    │    └─LayerNorm: 3-5                                   2,048
│    │    └─Linear: 3-6                                      1,0

#### Set train config and finetune

We fine-tune the model using PyTorch Lightning. A simple configuration is created to process parameters, and tuning is performed on a limited dataset to save time.
Note: The tuning process is random, so results may vary. To achieve the best results, consider increasing the number of epochs and the size of the dataset used


In [8]:
train_config = chrombert.finetune.TrainConfig(kind='zero_inflation',        
loss='zero_inflation',
max_epochs=2,
accumulate_grad_batches=2,
val_check_interval=2,
limit_val_batches=10,
tag='gep')
pl_module = train_config.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)



chrombert.finetune.train.pl_module.ZeroInflationPLModule

Then we start tuning!   
The trainer will save logs in a format compatible with TensorBoard, and the checkpoint with the lowest validation loss will be saved during the process.



In [9]:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # for tensorboard compatibility
os.environ["CUDA_VISIBLE_DEVICES"]='0'
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1, 
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu", 
    accumulate_grad_batches= train_config.accumulate_grad_batches, 
    fast_dev_run=False, 
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,   
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='gep'))
trainer.fit(pl_module,data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/gep
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type         | Params | Mode 
-----------------------------------------------
0 | model | ChromBERTGEP | 62.8 M | train
-----------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.022   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/chenqianqian/.conda/envs/demo/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


### Fine-tuning ChromBERTs to predict genome-wide changes in chromatin accessibility 

#### Set data config and data module
We use the `general` preset dataset config

In [10]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = None, batch_size = 4, num_workers = 4)
dataset_config

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': None, 'kind': 'GeneralDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'fasta_file': None, 'flank_window': 0})

In [13]:
chrom_accessibility_dir = f'{base_dir}/demo/transdifferentiation/chrom_accessibility'

In [14]:
data_module = chrombert.LitChromBERTFTDataModule(
    config = dataset_config, 
    train_params = {'supervised_file': f'{chrom_accessibility_dir}/train_sample.csv'}, 
    val_params = {'supervised_file':f'{chrom_accessibility_dir}/valid_sample.csv'}, 
    test_params = {'supervised_file':f'{chrom_accessibility_dir}/test_sample.csv'}
)
data_module.setup()

#### Set model config and model loading

we use `general` preset model config

In [15]:
model_config = chrombert.get_preset_model_config("general")
model_config

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt


ChromBERTFTConfig(genome='hg38', task='general', dim_output=1, mtx_mask='/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv', dropout=0.1, pretrain_ckpt='/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt', finetune_ckpt=None, ignore=False, ignore_index=(None, None), gep_flank_window=4, gep_parallel_embedding=False, gep_gradient_checkpoint=False, gep_zero_inflation=True, prompt_kind='cistrome', prompt_dim_external=512, dnabert2_ckpt=None)

In [16]:
model = model_config.init_model()
model.freeze_pretrain(2) ### freeze chrombert 6 transformer blocks
summary(model)

use organisim hg38; max sequence length including cls is 6392


Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         (7,680)
│    │    └─PositionalEmbedding: 3-2                    (4,909,056)
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-5                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-6                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-7                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-8                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-9                (6,497,280)
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock

#### Set train config and finetune

We fine-tune the model using PyTorch Lightning. A simple configuration is created to process parameters, and tuning is performed on a limited dataset to save time.
Note: The tuning process is random, so results may vary. To achieve the best results, consider increasing the number of epochs and the size of the dataset used

In [17]:
train_config = chrombert.finetune.TrainConfig(kind='regression',        
                                              loss='rmse',
                                              max_epochs=2,
                                              accumulate_grad_batches=2,
                                              val_check_interval=2,
                                              limit_val_batches=10,
                                              tag='chrom_accessibility')
train_config
train_module = train_config.init_pl_module(model)



In [18]:
callback_ckpt = pl.callbacks.ModelCheckpoint(monitor = f"{train_config.tag}_validation/{train_config.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs=train_config.max_epochs,
    log_every_n_steps=1, 
    limit_val_batches = train_config.limit_val_batches,
    val_check_interval = train_config.val_check_interval,
    accelerator="gpu", 
    accumulate_grad_batches= train_config.accumulate_grad_batches, 
    fast_dev_run=False, 
    precision="bf16-mixed",
    strategy="auto",
    callbacks=[
        pl.callbacks.LearningRateMonitor(),
        callback_ckpt,   
    ],
    logger=pl.loggers.TensorBoardLogger("lightning_logs", name='chrom_accessibility'),
    )
trainer.fit(train_module,data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs/chrom_accessibility
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M | train
---------------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.021   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/chenqianqian/.conda/envs/demo/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


## Load fine-tuned checkpoint and evaluation
ChromBERT is now fine-tuned! You can access the tuned model directly using `pl_module.model`. However, please note that due to specific settings in flash-attention, you cannot change the dropout probability, which may introduce some randomness in the output.

For consistent results, we recommend saving the checkpoint and loading it with the original model to ensure you have the tuned model.

### Predicting genome-wide changes in transcriptome 

#### Load limited fine-tuned checkpoint of predicting genome-wide changes in transcriptome 
Due to time constraints in tutorial, we only use downsampled test data for evaluation.

In [19]:
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import torch
import chrombert
import pandas as pd
from torchinfo import summary
import lightning.pytorch as pl 

In [24]:
import glob
gep_ft_ckpt = os.path.abspath(glob.glob('./lightning_logs/gep/version_0/checkpoints/*.ckpt')[0])
gep_ft_ckpt

'/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public/examples/tutorials/lightning_logs/gep/version_0/checkpoints/epoch=1-step=20.ckpt'

In [25]:

model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)
summary(ft_model)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public/examples/tutorials/lightning_logs/gep/version_0/checkpoints/epoch=1-step=20.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 112/112 parameters


Layer (type:depth-idx)                                       Param #
ChromBERTGEP                                                 --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               4,916,736
│    │    └─ModuleList: 3-2                                  51,978,240
├─GepHeader: 1-2                                             --
│    └─CistromeEmbeddingManager: 2-2                         --
│    └─Conv2d: 2-3                                           769
│    └─ReLU: 2-4                                             --
│    └─ResidualBlock: 2-5                                    --
│    │    └─Linear: 3-3                                      1,090,560
│    │    └─Linear: 3-4                                      1,049,600
│    │    └─LayerNorm: 3-5                                   2,048
│    │    └─Linear: 3-6                                      1,090

Evaluation

In [26]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/test_sample.csv', batch_size = 4, num_workers = 4)
dataset_config

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


DatasetConfig({'hdf5_file': '/home/chenqianqian/.cache/chrombert/data/hg38_6k_1kb.hdf5', 'supervised_file': '/home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/test_sample.csv', 'kind': 'MultiFlankwindowDataset', 'meta_file': '/home/chenqianqian/.cache/chrombert/data/config/hg38_6k_meta.json', 'ignore': False, 'ignore_object': None, 'batch_size': 4, 'num_workers': 4, 'shuffle': False, 'perturbation': False, 'perturbation_object': None, 'perturbation_value': 0, 'prompt_kind': None, 'prompt_regulator': None, 'prompt_regulator_cache_file': None, 'prompt_celltype': None, 'prompt_celltype_cache_file': None, 'fasta_file': None, 'flank_window': 4})

In [28]:
from tqdm import tqdm
import torchmetrics as tm
dl = dataset_config.init_dataloader()
ft_model.cuda()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for idx, batch in enumerate(tqdm(dl, total=len(dl))):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch)[1].cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

100%|██████████| 5/5 [00:08<00:00,  1.78s/it]

{'pearsonr': tensor(nan), 'spearmanr': tensor(0.), 'mse': tensor(0.3469), 'mae': tensor(0.4576), 'r2': tensor(-0.8482)}





#### Load the total dataset to fine-tune the checkpoint for predicting genome-wide changes in the transcriptome

Using a limited fine-tuned checkpoint results in poor evaluation performance; therefore, we load the entire dataset to fine-tune the checkpoint for predicting genome-wide changes in the transcriptome.
Due to time constraints in tutorial, we only use downsampled test data for evaluation.

In [30]:
gep_ft_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt'

model_config = chrombert.get_preset_model_config("gep",gep_flank_window=4, dropout=0)
ft_model = model_config.init_model(finetune_ckpt = gep_ft_ckpt)
summary(ft_model)

dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/test_sample.csv', batch_size = 4, num_workers = 4)


from tqdm import tqdm
import torchmetrics as tm
dl = dataset_config.init_dataloader()
ft_model.cuda()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch)[1].cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loaded 112/112 parameters
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 5/5 [00:08<00:00,  1.78s/it]

{'pearsonr': tensor(0.4961), 'spearmanr': tensor(0.4652), 'mse': tensor(0.1490), 'mae': tensor(0.3108), 'r2': tensor(0.2062)}





### Predicting genome-wide changes in chromatin accessibility 

#### Load limited fine-tuned checkpoint of predicting genome-wide changes in chromatin accessibility
Due to time constraints in tutorial, we only use downsampled test data for evaluation.

In [31]:
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
import torch
import chrombert
import pandas as pd
from torchinfo import summary
import lightning.pytorch as pl 

In [32]:
import glob
chrom_accessibility_ft_ckpt = os.path.abspath(glob.glob('./lightning_logs/chrom_accessibility/version_0/checkpoints/*.ckpt')[0])
chrom_accessibility_ft_ckpt

'/shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public/examples/tutorials/lightning_logs/chrom_accessibility/version_0/checkpoints/epoch=0-step=2.ckpt'

In [33]:
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /shared/chenqianqian/data_copy1/chenqianqian/finetune/test_model/ChromBERT_public/examples/tutorials/lightning_logs/chrom_accessibility/version_0/checkpoints/epoch=0-step=2.ckpt
Loading from pl module, remove prefix 'model.'
Loaded 110/110 parameters


In [34]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chrom_accessibility_dir}/test_sample.csv', batch_size = 32, num_workers = 4)


from tqdm import tqdm
import torchmetrics as tm
dl = dataset_config.init_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 1/1 [00:01<00:00,  1.44s/it]

{'pearsonr': tensor(-0.1836), 'spearmanr': tensor(-0.0692), 'mse': tensor(0.6062), 'mae': tensor(0.4786), 'r2': tensor(-0.5403)}





#### Load the total dataset to fine-tune the checkpoint for predicting genome-wide changes in the chromatin accessibility

Using a limited fine-tuned checkpoint results in poor evaluation performance; therefore, we load the entire dataset to fine-tune the checkpoint for predicting genome-wide changes in the chromatin accessibility.
Due to time constraints in tutorial, we only use downsampled test data for evaluation.

In [35]:
chrom_accessibility_ft_ckpt = f'{chrom_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt'
model_config = chrombert.get_preset_model_config("general")
ft_model = model_config.init_model(finetune_ckpt = chrom_accessibility_ft_ckpt,dropout=0)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loaded 110/110 parameters


In [36]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chrom_accessibility_data_dir}/test_sample.csv', batch_size = 32, num_workers = 4)


from tqdm import tqdm
import torchmetrics as tm
dl = dataset_config.init_dataloader()
ft_model = ft_model.cuda().eval()
with torch.no_grad():
    y_preds = []
    y_labels = []
    for batch in tqdm(dl,total=len(dl)):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].cuda()
        y_pred = ft_model(batch).cpu()
        y_label = batch['label'].cpu()
        y_preds.append(y_pred)
        y_labels.append(y_label)
    y_preds = torch.cat(y_preds)
    y_labels = torch.cat(y_labels)
predicts = y_preds.view(-1)
labels = y_labels.view(-1)
metrics_pearsonr = tm.PearsonCorrCoef()
metrics_spearmanr = tm.SpearmanCorrCoef()
metrics_mse = tm.MeanSquaredError()
metrics_mae = tm.MeanAbsoluteError()
metrics_r2 = tm.R2Score()
score_pearsonr = metrics_pearsonr(predicts, labels)
score_spearmanr = metrics_spearmanr(predicts, labels)
score_mse = metrics_mse(predicts, labels)
score_mae = metrics_mae(predicts, labels)
score_r2 = metrics_r2(predicts, labels)
scores = {
    "pearsonr": score_pearsonr,
    "spearmanr": score_spearmanr,
    "mse": score_mse,
    "mae": score_mae,
    "r2": score_r2,
    }
print(scores)

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 1/1 [00:01<00:00,  1.32s/it]

{'pearsonr': tensor(0.8971), 'spearmanr': tensor(0.8857), 'mse': tensor(0.0807), 'mae': tensor(0.1940), 'r2': tensor(0.7950)}





## Use tuned to get regulator embedding and identified key regulator in cell state transition

Using a limited fine-tuned checkpoint results in poor evaluation performance; therefore, we load the entire dataset to fine-tune the checkpoint to get regulator embedding.
Due to time constraints in tutorial, we only use downsampled data.


### Get regulator embedding and regulator rank in transcriptome

In [37]:
model_tuned = chrombert.get_preset_model_config(
    "gep", 
    gep_flank_window = 4,
    dropout = 0,
    finetune_ckpt = f'{gep_dir}/gep_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/transcriptome/gep_fibroblast_to_myoblast.ckpt
Loaded 112/112 parameters


In [38]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)

Layer (type:depth-idx)                                       Param #
ChromBERTEmbedding                                           --
├─PoolFlankWindow: 1-1                                       --
│    └─ChromBERT: 2-1                                        --
│    │    └─BERTEmbedding: 3-1                               4,916,736
│    │    └─ModuleList: 3-2                                  51,978,240
├─CistromeEmbeddingManager: 1-2                              --
Total params: 56,894,976
Trainable params: 56,894,976
Non-trainable params: 0

In [39]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
up_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_gep_embs.append(emb)
up_gep_embs = torch.cat(up_gep_embs)
up_gep_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:43<00:00, 10.80s/it]


torch.Size([100, 1064, 768])

In [40]:
dataset_config = chrombert.get_preset_dataset_config("multi_flank_window",supervised_file = f'{gep_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
dl = dataset_config.init_dataloader()
nochange_gep_embs = []
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_gep_embs.append(emb)
nochange_gep_embs = torch.cat(nochange_gep_embs)
nochange_gep_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:43<00:00, 10.95s/it]


torch.Size([100, 1064, 768])

We consider only factors below, remove histone modifications and chromatin accessibility. Then we can get similarity between factors, which represent the potential interactions between factors. 

In [41]:
with open(os.path.join(base_dir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]


indices = np.in1d(model_emb.list_regulator,factors)
names = np.array(model_emb.list_regulator)[indices]
up_gep_embs = up_gep_embs.mean(axis=0)[indices]
nochange_gep_embs = nochange_gep_embs.mean(axis=0)[indices]

In [42]:
up_gep_embs.shape, nochange_gep_embs.shape

(torch.Size([996, 768]), torch.Size([996, 768]))

In [43]:
from sklearn.metrics.pairwise import cosine_similarity
gep_similarity = [cosine_similarity(up_gep_embs[i].reshape(1, -1), nochange_gep_embs[i].reshape(1, -1))[0, 0] for i in range(up_gep_embs.shape[0])]
gep_similarity_df = pd.DataFrame({'factors':names,'similarity':gep_similarity}).sort_values(by='similarity').reset_index(drop=True)
gep_similarity_df['rank']=gep_similarity_df.index + 1
gep_similarity_df

Unnamed: 0,factors,similarity,rank
0,rpe,0.902802,1
1,klf15,0.925641,2
2,nr3c2,0.926707,3
3,myf5,0.938897,4
4,pgr,0.950597,5
...,...,...,...
991,dbp,0.998239,992
992,tfcp2,0.998259,993
993,mafk,0.998279,994
994,tfam,0.998362,995


In [44]:
gep_similarity_df[gep_similarity_df['factors']=='myod1']

Unnamed: 0,factors,similarity,rank
17,myod1,0.971539,18


### Get regulator embedding and regulator rank in chromatin accessibility

In [45]:
model_tuned = chrombert.get_preset_model_config(
    "general", 
    dropout = 0,
    finetune_ckpt = f'{chrom_accessibility_dir}/chrom_accessibility_fibroblast_to_myoblast.ckpt').init_model() # use absolute path here, to avoid mixing of preset

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /home/chenqianqian/.cache/chrombert/data/demo/transdifferentiation/chrom_accessibility/chrom_accessibility_fibroblast_to_myoblast.ckpt
Loaded 110/110 parameters


In [46]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)

Layer (type:depth-idx)                                  Param #
ChromBERTEmbedding                                      --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         7,680
│    │    └─PositionalEmbedding: 3-2                    4,909,056
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                6,497,280
│    │    └─EncoderTransformerBlock: 3-5                6,497,280
│    │    └─EncoderTransformerBlock: 3-6                6,497,280
│    │    └─EncoderTransformerBlock: 3-7                6,497,280
│    │    └─EncoderTransformerBlock: 3-8                6,497,280
│    │    └─EncoderTransformerBlock: 3-9                6,497,280
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock: 3-11          

In [47]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chrom_accessibility_dir}/up_data_sample.csv', batch_size = 32, num_workers = 4)
up_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        up_chrom_acc_embs.append(emb)
up_chrom_acc_embs = torch.cat(up_chrom_acc_embs,dim=0)
up_chrom_acc_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


torch.Size([100, 1064, 768])

In [48]:
dataset_config = chrombert.get_preset_dataset_config("general",supervised_file = f'{chrom_accessibility_dir}/nochange_data_sample.csv', batch_size = 32, num_workers = 4)
nochange_chrom_acc_embs=[]
dl = dataset_config.init_dataloader()
for batch in tqdm(dl):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch).cpu()
        nochange_chrom_acc_embs.append(emb)
nochange_chrom_acc_embs = torch.cat(nochange_chrom_acc_embs,dim=0)
nochange_chrom_acc_embs.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


100%|██████████| 4/4 [00:05<00:00,  1.39s/it]


torch.Size([100, 1064, 768])

In [49]:
with open(os.path.join(base_dir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]


indices = np.in1d(model_emb.list_regulator,factors)
names = np.array(model_emb.list_regulator)[indices]
up_chrom_acc_embs = up_chrom_acc_embs.mean(axis=0)[indices]
nochange_chrom_acc_embs = nochange_chrom_acc_embs.mean(axis=0)[indices]

In [50]:
up_chrom_acc_embs.shape, nochange_chrom_acc_embs.shape

(torch.Size([996, 768]), torch.Size([996, 768]))

In [51]:
from sklearn.metrics.pairwise import cosine_similarity
chrom_acc_similarity = [cosine_similarity(up_chrom_acc_embs[i].reshape(1, -1), nochange_chrom_acc_embs[i].reshape(1, -1))[0, 0] for i in range(up_chrom_acc_embs.shape[0])]
chrom_acc_similarity_df = pd.DataFrame({'factors':names,'similarity':chrom_acc_similarity}).sort_values(by='similarity').reset_index(drop=True)
chrom_acc_similarity_df['rank']=chrom_acc_similarity_df.index + 1
chrom_acc_similarity_df

Unnamed: 0,factors,similarity,rank
0,myod1,0.038196,1
1,myf5,0.139169,2
2,myog,0.158149,3
3,pax3,0.169502,4
4,tead1,0.264310,5
...,...,...,...
991,zbtb10,0.979514,992
992,znf250,0.979778,993
993,dr1,0.980107,994
994,hoxa1,0.980110,995


In [52]:
chrom_acc_similarity_df[chrom_acc_similarity_df['factors']=='myod1']

Unnamed: 0,factors,similarity,rank
0,myod1,0.038196,1


### Average ranking from these two modalities

In [53]:
average_rank_df = pd.merge(gep_similarity_df, chrom_acc_similarity_df, on='factors', how='inner', suffixes=('_gep', '_chrom_acc'))
average_rank_df['averge_rank'] = ((average_rank_df['rank_gep']+average_rank_df['rank_chrom_acc'])/2).rank().astype(int)
average_rank_df=average_rank_df.sort_values(by='averge_rank')
average_rank_df

Unnamed: 0,factors,similarity_gep,rank_gep,similarity_chrom_acc,rank_chrom_acc,averge_rank
3,myf5,0.938897,4,0.139169,2,1
6,dux4,0.954232,7,0.462057,6,2
2,nr3c2,0.926707,3,0.712065,12,3
11,tead1,0.962442,12,0.264310,5,4
17,myod1,0.971539,18,0.038196,1,5
...,...,...,...,...,...,...
992,tfcp2,0.998259,993,0.975744,966,992
987,mafg,0.998105,988,0.977242,977,993
973,pitx1,0.997836,974,0.979403,991,993
983,znf250,0.997955,984,0.979778,993,995


In [54]:
average_rank_df[average_rank_df['factors']=='myod1']

Unnamed: 0,factors,similarity_gep,rank_gep,similarity_chrom_acc,rank_chrom_acc,averge_rank
17,myod1,0.971539,18,0.038196,1,5
