# Example for context-specific TRN: functional collaborations of EZH2 at functionally distinct loci
Transcriptional regulatory networks (TRNs) are complex and dynamic, and inferring them at specific loci is challenging. In this tutorial, we walk through context-specific TRN analysis with ChromBERT, using EZH2 as an example of a regulator whose functional collaborations differ between distinct loci.

**Attention: Go through this [tutorial](https://chrombert.readthedocs.io/en/latest/tutorial_finetuning_ChromBERT.html) first to get familiar with the basic usage of ChromBERT.**

## Preprocessing dataset

To identify classical and non-classical EZH2 sites in human embryonic stem cells (hESCs), we utilize the EZH2 peak dataset (GSM1003524) and the H3K27me3 peak dataset (GSM1498900). Classical EZH2 sites are defined as regions where EZH2 co-localizes with H3K27me3, while non-classical EZH2 sites are identified as regions where EZH2 is present without H3K27me3 co-localization.


In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # select which GPU to use

import sys 
import torch 
import numpy as np 
import pandas as pd 
from tqdm import tqdm 
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary

import lightning.pytorch as pl
basedir =  os.path.expanduser("~/.cache/chrombert/data")





In [2]:
peak_ezh2 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1003524_EZH2.bed")
!head -3 {peak_ezh2}

chr3	93470270	93470880	peak7668	2339	.	23.98034	243.39191	233.90086	232
chr2	91477646	91478694	peak6148	1127	.	10.28381	119.74835	112.79007	430
chr16	46390276	46390857	peak4186	1039	.	8.55891	110.78978	103.94416	350


In [3]:
peak_h3k27me3 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1498900_H3K27me3.bed")
!head -3 {peak_h3k27me3}

chr10	100480580	100483730	peak1145	6.37280
chr10	100519337	100521146	peak1146	6.86372
chr10	100655029	100656371	peak1147	15.71588


In [4]:
# Align genomic coordinates from the narrowPeak file to the Human-Cistrome-6k dataset regions
from chrombert.scripts.chrombert_make_dataset import get_overlap
ref_regions = os.path.join(basedir, "config", "hg38_6k_1kb_region.bed")
df1 = get_overlap(
    supervised = peak_ezh2, 
    regions = ref_regions,
    no_filter = False,
).assign(label = lambda df: df["label"] > 0 )
df2 = get_overlap(
    supervised = peak_h3k27me3, 
    regions = ref_regions,
    no_filter = True,
).assign(label = lambda df: df["label"] > 0 )
df1.head(), df2.head()

(  chrom   start     end  build_region_index  label
 0  chr1  870000  871000                 174   True
 1  chr1  905000  906000                 204   True
 2  chr1  923000  924000                 220   True
 3  chr1  924000  925000                 221   True
 4  chr1  925000  926000                 222   True,
   chrom  start    end  build_region_index  label
 0  chr1  10000  11000                   0  False
 1  chr1  16000  17000                   1  False
 2  chr1  17000  18000                   2  False
 3  chr1  29000  30000                   3  False
 4  chr1  30000  31000                   4  False)
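
To make the alignment step more concrete, the sketch below maps a peak interval onto fixed 1-kb windows. It only illustrates the idea and is not how `get_overlap` is implemented: `windows_overlapping` is a hypothetical helper, and it assumes the windows tile the genome contiguously, whereas the reference file contains only a selected subset of regions.

```python
def windows_overlapping(start, end, size=1000):
    """Return (start, end) of every fixed-size window overlapped by
    the half-open interval [start, end)."""
    first = (start // size) * size
    last = ((end - 1) // size) * size
    return [(s, s + size) for s in range(first, last + size, size)]

# Coordinates of the first two EZH2 peaks printed earlier
print(windows_overlapping(93470270, 93470880))
# [(93470000, 93471000)]
print(windows_overlapping(91477646, 91478694))
# [(91477000, 91478000), (91478000, 91479000)]
```

A peak that crosses a window boundary, like the second one, would mark two regions as positive under this scheme.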

In [5]:
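# Keep the EZH2-bound regions (df1) and attach the H3K27me3 status from df2 as the label:
# label == True marks classical sites, label == False marks non-classical sites
df_supervised = df1.rename(columns = {"label": "EZH2"}).merge(df2)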
df_supervised

| | chrom | start | end | build_region_index | EZH2 | label |
|---|---|---|---|---|---|---|
| 0 | chr1 | 870000 | 871000 | 174 | True | True |
| 1 | chr1 | 905000 | 906000 | 204 | True | True |
| 2 | chr1 | 923000 | 924000 | 220 | True | True |
| 3 | chr1 | 924000 | 925000 | 221 | True | False |
| 4 | chr1 | 925000 | 926000 | 222 | True | True |
| ... | ... | ... | ... | ... | ... | ... |
| 11003 | chrX | 154750000 | 154751000 | 2134828 | True | True |
| 11004 | chrY | 5001000 | 5002000 | 2135730 | True | False |
| 11005 | chrY | 10994000 | 10995000 | 2136403 | True | False |
| 11006 | chrY | 26670000 | 26671000 | 2137888 | True | False |


In [6]:
df_supervised.groupby("label").size() # the two classes are nearly balanced

label
False    5272
True     5736
dtype: int64

In [7]:
# Then we split the dataset into training, validation and test sets

from sklearn.model_selection import train_test_split
df_train, df_temp = train_test_split(df_supervised, test_size=0.2, random_state=42, stratify = df_supervised['label'])
df_valid, df_test = train_test_split(df_temp, test_size=0.5, random_state=42, stratify = df_temp['label'])

os.makedirs("tmp_ezh2", exist_ok=True)
df_train.to_csv(os.path.join("tmp_ezh2", "train.csv"))
df_valid.to_csv(os.path.join("tmp_ezh2", "valid.csv"))
df_test.to_csv(os.path.join("tmp_ezh2", "test.csv"))

len(df_train), len(df_valid), len(df_test)

(8806, 1101, 1101)
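
The split sizes printed above follow from simple arithmetic. The sketch below reproduces them, assuming that a fractional `test_size` is rounded up to a whole number of samples and the remainder goes to the other split; `split_sizes` is a hypothetical helper written only for this check.

```python
import math

def split_sizes(n_samples, held_out_fraction):
    """Sizes of (remaining, held-out) when the held-out count is rounded up."""
    n_held_out = math.ceil(n_samples * held_out_fraction)
    return n_samples - n_held_out, n_held_out

n_total = 5272 + 5736                        # label counts from the cell above
n_train, n_temp = split_sizes(n_total, 0.2)  # 80% train, 20% held out
n_valid, n_test = split_sizes(n_temp, 0.5)   # held-out half/half
print(n_train, n_valid, n_test)              # 8806 1101 1101
```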

## Fine-tune

In this section, we fine-tune ChromBERT for our specific task. The process closely follows the standard ChromBERT fine-tuning workflow, with two important modifications:

- **Dataset Preparation**: The `ignore_object` parameter is used to omit H3K27me3-related cistromes from the ChromBERT input, so that H3K27me3 signal cannot leak into the prediction of H3K27me3 co-localization.  
- **Model Instantiation**: An `ignore_index` parameter, derived from the dataset, is passed to the model configuration so that the model omits the same cistromes.  

Let's get started!

### Instructions for dataset: omit specified regulators

In [8]:
dc = chrombert.get_preset_dataset_config(
    "general", 
    supervised_file = None, 
    ignore = False, ignore_object = "h3k27me3" # omission is turned off (ignore = False)
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
ds[1]["input_ids"].shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


torch.Size([6391])

In [9]:
# Omit the H3K27me3-related cistromes to avoid data leakage
dc = chrombert.get_preset_dataset_config(
    "general", 
    supervised_file = None, 
    ignore = True, ignore_object = "h3k27me3" 
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))

# Get the ignore_index needed to instantiate the model.
# Currently, a dataset supports only one ignore object shared by all samples,
# so ignore_index can be taken from any sample.
ignore_index = ds[0]["ignore_index"] 
ds[1]["input_ids"].shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


torch.Size([6188])

As shown above, the dataset functions as expected after omitting the specified cistromes. The input sequence length is reduced from 6391 to 6188, because the 203 H3K27me3-related cistromes are omitted and do not participate in the training process.
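
The effect of omission on the sequence length can be pictured with a boolean mask over token positions. This is a toy sketch, not ChromBERT's implementation: `drop_positions` is a hypothetical helper, the positions are made up, and the actual structure of `ignore_index` is not shown in this tutorial.

```python
import numpy as np

def drop_positions(tokens, ignored):
    """Remove the given positions from a 1-D token array."""
    keep = np.ones(len(tokens), dtype=bool)
    keep[ignored] = False
    return tokens[keep]

tokens = np.arange(10)      # toy sequence of 10 cistrome tokens
ignored = [2, 5, 7]         # toy positions to omit
shortened = drop_positions(tokens, ignored)
print(shortened.tolist())   # [0, 1, 3, 4, 6, 8, 9]

# Same bookkeeping with the lengths reported in this tutorial
print(6391 - 203)           # 6188
```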

A small note: the model is fine-tuned using PyTorch Lightning, and the dataset is wrapped in the `lightning.pytorch.LightningDataModule` class for seamless integration. 

In [10]:
data_module = chrombert.LitChromBERTFTDataModule(
    config = dc,
    train_params = dict(supervised_file = os.path.join("tmp_ezh2", "train.csv")),
    val_params = dict(supervised_file = os.path.join("tmp_ezh2", "valid.csv")),
    test_params = dict(supervised_file = os.path.join("tmp_ezh2", "test.csv")),
)
data_module

<chrombert.finetune.dataset.data_module.LitChromBERTFTDataModule at 0x7f6abdb92f80>

### Instantiate the Model
Next, we instantiate the model, passing the `ignore_index` obtained above.

In [11]:
model = chrombert.get_preset_model_config(
    "general", 
    ignore = True, ignore_index = ignore_index  # ignore_index from above
).init_model()
model.freeze_pretrain(trainable=2) # fine-tune only the last two transformer layers
summary(model, depth=2)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators


Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               (4,916,736)
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,249,152
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
Total params: 62,771,714
Trainable params: 18,871,298
Non-trainable params: 43,900,416
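
The trainable and non-trainable totals in the summary can be checked by hand. The sketch below assumes that the encoder consists of eight equally sized transformer blocks (as the embedding-manager summary later in this tutorial lists) and that `freeze_pretrain(trainable=2)` leaves the last two blocks plus the whole header trainable; both are inferences from the printed numbers, not documented behaviour.

```python
# Parameter counts copied from the summary above
embedding = 4_916_736
encoder = 51_978_240
header = 769 + 3_249_152 + 2_166_528 + 460_032 + 257

n_blocks = 8                       # assumption: eight equal-sized blocks
per_block = encoder // n_blocks    # 6,497,280 parameters per block

trainable = 2 * per_block + header
frozen = embedding + (n_blocks - 2) * per_block
print(trainable, frozen, trainable + frozen)
# 18871298 43900416 62771714
```

The three values match the `Trainable params`, `Non-trainable params`, and `Total params` lines of the summary.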

### Fine-tune

We fine-tune the model with PyTorch Lightning. `TrainConfig` collects the training hyperparameters (task kind, loss, number of epochs, learning rate) and wraps the model in a Lightning module. To save time, tuning is performed on a limited dataset and for a single epoch.

Note: The tuning process involves randomness, so results may vary. For better performance, consider increasing the number of epochs and the size of the dataset used.


In [12]:
tc = chrombert.finetune.train.TrainConfig(
    kind = "classification",
    loss = "bce", # specify "bce" to use Binary Cross-Entropy (BCE) loss. Use "focal" to apply Focal Loss instead.
    max_epochs = 1,
    lr = 1e-4
)
pl_module = tc.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)

chrombert.finetune.train.pl_module.ClassificationPLModule

Next, we begin the tuning process!  
The trainer writes logs in a TensorBoard-compatible format and may save multiple checkpoints along the way.  
In this tutorial, however, we use the model parameters from the end of training rather than a saved checkpoint, because the run is too short for the checkpoints to be useful.  

In [13]:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # for tensorboard compatibility
callback_ckpt = pl.callbacks.ModelCheckpoint( monitor = f"{tc.tag}_validation/{tc.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs = tc.max_epochs,
    accelerator = "gpu",
    precision = "bf16-mixed",
    fast_dev_run = False,
    accumulate_grad_batches = 16, 
    logger = pl.loggers.TensorBoardLogger(os.path.join("tmp_ezh2","logs"), name = "ezh2"),
    val_check_interval = 128,
    limit_val_batches = 128,
    log_every_n_steps = 1,
    callbacks = [ callback_ckpt, pl.callbacks.LearningRateMonitor() ],
)
trainer.fit(pl_module, data_module)
pl_module.save_ckpt(os.path.join("tmp_ezh2", "ezh2.ckpt"))

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: tmp_ezh2/logs/ezh2
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M
-------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.087   Total estimated model params siz

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


`Trainer.fit` stopped: `max_epochs=1` reached.



## Use Fine-Tuned Model to Obtain Regulator Embeddings  

ChromBERT has been successfully fine-tuned! You can access the tuned model directly through `pl_module.model`. However, because of how flash-attention is configured, calling `model.eval()` does not turn off dropout, which may introduce some randomness in the output.  

To get consistent results, we recommend saving the checkpoint and loading it into a freshly instantiated model with `dropout = 0`. This guarantees that you are working with the fine-tuned weights and that dropout is disabled.

In [14]:
model_tuned = chrombert.get_preset_model_config(
    "general", 
    dropout = 0,
    ignore = True, ignore_index = ignore_index,   # ignore_index from above
    finetune_ckpt = os.path.abspath(os.path.join("tmp_ezh2", "ezh2.ckpt")) # use absolute path here
).init_model()
# or use model_tuned = pl_module.model
summary(model_tuned, depth = 2)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
use organisim hg38; max sequence length is 6391
Ignoring 203 cistromes and 1 regulators
Loading checkpoint from /home/yangdongxu/work/source/repos/ChromBERT/examples/tutorials/tmp_ezh2/ezh2.ckpt
Loaded 110/110 parameters


Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               4,916,736
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,249,152
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
Total params: 62,771,714
Trainable params: 62,771,714
Non-trainable params: 0

We can then obtain the embedding manager, following the tutorial on extracting embeddings.

In [15]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)

Ignoring 203 cistromes and 1 regulators


Layer (type:depth-idx)                                  Param #
ChromBERTEmbedding                                      --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         7,680
│    │    └─PositionalEmbedding: 3-2                    4,909,056
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                6,497,280
│    │    └─EncoderTransformerBlock: 3-5                6,497,280
│    │    └─EncoderTransformerBlock: 3-6                6,497,280
│    │    └─EncoderTransformerBlock: 3-7                6,497,280
│    │    └─EncoderTransformerBlock: 3-8                6,497,280
│    │    └─EncoderTransformerBlock: 3-9                6,497,280
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock: 3-11          

In [16]:
dc_test = data_module.test_config
ds_test = dc_test.init_dataset()
dl_test = dc_test.init_dataloader(batch_size = 1)
len(ds_test), list(ds_test[0].keys())

(1101,
 ['input_ids',
  'position_ids',
  'region',
  'build_region_index',
  'ignore_index',
  'label'])

In [17]:
# Obtain embeddings for both classical and non-classical EZH2 sites
embs_classicial = []
embs_nonclassicial = []
for batch in tqdm(dl_test):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch)
    if batch["label"].item() == 1:
        embs_classicial.append(emb)
    else:
        embs_nonclassicial.append(emb)

print(len(embs_classicial), len(embs_nonclassicial))
embs_classicial = torch.cat(embs_classicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_nonclassicial = torch.cat(embs_nonclassicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_classicial.shape, embs_nonclassicial.shape

100%|██████████| 1101/1101 [01:28<00:00, 12.39it/s]


573 528


((1072, 768), (1072, 768))
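
The aggregation in the cell above stacks one embedding array per region and averages over the region axis, leaving one vector per regulator. A toy NumPy sketch of the same shape bookkeeping, with random stand-in data and made-up sizes (3 regions, 4 regulators, 5 dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-ins for per-region outputs, each shaped (1, regulators, dim)
per_region = [rng.normal(size=(1, 4, 5)) for _ in range(3)]

stacked = np.concatenate(per_region, axis=0)  # (regions, regulators, dim)
mean_emb = stacked.mean(axis=0)               # (regulators, dim)
print(stacked.shape, mean_emb.shape)          # (3, 4, 5) (4, 5)
```

With the real data, the same steps yield the `(1072, 768)` arrays shown above: 1072 regulators, each with a 768-dimensional mean embedding.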

We restrict the analysis to transcription factors, excluding histone modifications and chromatin accessibility. We then compute the cosine similarity between transcription factor embeddings as a measure of their potential interactions.

In [18]:
with open(os.path.join(basedir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]
factors[:3], len(factors)

(['adnp', 'aebp2', 'aff1'], 991)

In [19]:
factors[-1]

'zzz3'

In [20]:
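# Boolean mask over the model's regulators: True where the regulator is a transcription factor
# (np.isin replaces np.in1d, which is deprecated in recent NumPy releases)
indices = np.isin(model_emb.list_regulator, factors)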
names = np.array(model_emb.list_regulator)[indices]
embs_classicial = embs_classicial[indices]
embs_nonclassicial = embs_nonclassicial[indices]
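
The filtering above relies on a boolean mask aligned with the regulator order of the embedding rows. A toy sketch with made-up regulator names and a small stand-in embedding array:

```python
import numpy as np

list_regulator = ["dnase", "ezh2", "h3k27ac", "suz12"]  # toy regulator order
factors = ["ezh2", "stat3", "suz12"]                     # toy factor list
emb = np.arange(8).reshape(4, 2)                         # one row per regulator

mask = np.isin(list_regulator, factors)
names = np.array(list_regulator)[mask]
print(mask.tolist())       # [False, True, False, True]
print(names.tolist())      # ['ezh2', 'suz12']
print(emb[mask].tolist())  # [[2, 3], [6, 7]]
```

Factors that are absent from the regulator list (here `stat3`) are simply ignored, and the kept rows stay in the model's regulator order.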


In [21]:
from sklearn.metrics.pairwise import cosine_similarity
cos_classicial_matrix = cosine_similarity(embs_classicial)
cos_nonclassicial_matrix = cosine_similarity(embs_nonclassicial)
df_cos_classicial = pd.DataFrame(cos_classicial_matrix, columns = names, index = names)
df_cos_nonclassicial = pd.DataFrame(cos_nonclassicial_matrix, columns = names, index = names)
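
For reference, cosine similarity between rows is the dot product of the rows after scaling each to unit length. A minimal NumPy sketch on toy vectors; `cosine_matrix` is a hypothetical helper for illustration and assumes no row is all zeros:

```python
import numpy as np

def cosine_matrix(x):
    """Pairwise cosine similarity between the rows of a 2-D array."""
    unit = x / np.linalg.norm(x, axis=1, keepdims=True)
    return unit @ unit.T

x = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
sim = cosine_matrix(x)
print(np.round(sim, 3))
```

The result is symmetric with ones on the diagonal, which is why EZH2's similarity to itself appears as 1.0 in the tables further down.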

In [22]:
# Use the 95th percentile of all pairwise similarities as the threshold for the most related regulator pairs
thre_class = np.percentile(cos_classicial_matrix.flatten(), 95)
thre_nonclass = np.percentile(cos_nonclassicial_matrix.flatten(), 95)
thre_class, thre_nonclass

(0.5776264667510986, 0.5518490076065063)
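
Once a threshold is chosen, regulator pairs whose similarity exceeds it can be read off as network edges. The sketch below does this on a toy similarity matrix; the names, values, and cutoff are made up for illustration.

```python
import numpy as np

names = np.array(["a", "b", "c", "d"])   # toy regulator names
sim = np.array([
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.3, 0.4],
    [0.2, 0.3, 1.0, 0.8],
    [0.1, 0.4, 0.8, 1.0],
])
threshold = 0.5                          # toy cutoff, not the 95th percentile

# The upper triangle (k=1) skips self-pairs and counts each pair once
i, j = np.triu_indices_from(sim, k=1)
keep = sim[i, j] > threshold
edges = [(str(names[a]), str(names[b]), float(sim[a, b]))
         for a, b in zip(i[keep], j[keep])]
print(edges)  # [('a', 'b', 0.9), ('c', 'd', 0.8)]
```

Note that the percentile in the cell above is computed over the full flattened matrix, which includes the diagonal of self-similarities and both copies of each pair.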

Now we identify the TRNs associated with the classical and non-classical functions of EZH2. As the tables below show, the factors most related to EZH2 at classical sites are associated with the Polycomb complex, such as SUZ12. In contrast, the factors whose similarity to EZH2 increases most at non-classical sites tend to be associated with transcriptional activation, including EP300 and STAT3.

In [23]:
df_cos_ezh2 = pd.DataFrame(index =names, data = {"classical":df_cos_classicial.loc["ezh2",:],"nonclassical":df_cos_nonclassicial.loc["ezh2",:]})
df_cos_ezh2["diff"] = df_cos_ezh2["classical"] - df_cos_ezh2["nonclassical"]
df_cos_ezh2

| | classical | nonclassical | diff |
|---|---|---|---|
| adnp | 0.193229 | 0.359759 | -0.166530 |
| aebp2 | 0.190192 | 0.309486 | -0.119294 |
| aff1 | 0.396427 | 0.479881 | -0.083454 |
| aff4 | 0.300634 | 0.452184 | -0.151550 |
| ago1 | 0.234017 | 0.291643 | -0.057626 |
| ... | ... | ... | ... |
| zscan5a | 0.218051 | 0.334104 | -0.116052 |
| zta | 0.181115 | 0.251131 | -0.070015 |
| zxdb | 0.213376 | 0.294275 | -0.080898 |
| zxdc | 0.149416 | 0.287630 | -0.138214 |


In [24]:
df_cos_ezh2.query("classical > @thre_class ").sort_values("diff", ascending = False).head(10)

| | classical | nonclassical | diff |
|---|---|---|---|
| ezh1 | 0.644102 | 0.578448 | 0.0656538 |
| pcgf1 | 0.627553 | 0.564524 | 0.06302953 |
| kdm2b | 0.600312 | 0.543989 | 0.05632281 |
| jarid2 | 0.660526 | 0.609975 | 0.05055112 |
| rybp | 0.594199 | 0.549152 | 0.04504716 |
| suz12 | 0.875248 | 0.833695 | 0.04155296 |
| bcor | 0.616664 | 0.587787 | 0.02887654 |
| eed | 0.597332 | 0.576189 | 0.0211432 |
| ezh2 | 1.0 | 1.0 | -3.576279e-07 |
| cbx8 | 0.584771 | 0.621309 | -0.03653848 |


In [25]:
df_cos_ezh2.query("nonclassical > @thre_nonclass ").sort_values("diff", ascending = True).head(10)

| | classical | nonclassical | diff |
|---|---|---|---|
| foxm1 | 0.395081 | 0.604784 | -0.209703 |
| med1 | 0.349747 | 0.556351 | -0.206604 |
| stat3 | 0.424211 | 0.614199 | -0.189988 |
| ep300 | 0.411506 | 0.59994 | -0.188434 |
| hinfp | 0.391241 | 0.577772 | -0.186532 |
| rela | 0.385995 | 0.568461 | -0.182466 |
| smarca4 | 0.383161 | 0.552061 | -0.1689 |
| brca1 | 0.415508 | 0.583485 | -0.167977 |
| stat1 | 0.431552 | 0.594642 | -0.16309 |
| e2f4 | 0.423399 | 0.56879 | -0.145391 |


## The end
This tutorial offers a comprehensive guide to context-specific TRN inference, using EZH2's functional collaborations as an example.  
We hope you find it both helpful and informative.  
If you have any questions or require further assistance, please don't hesitate to reach out. Thank you for following along!