# Tutorial for locus-specific TRN inference
Inferring transcriptional regulatory networks (TRNs) at specific loci is challenging because regulation is complex and dynamic. ChromBERT offers a way to approach this task. In this tutorial, we walk through locus-specific TRN inference, using the non-classical functions of EZH2 as an example.

## Preprocessing dataset
To define classical and non-classical sites of EZH2 in human embryonic stem cells (hESC), we use the EZH2 peak dataset (GSM1003524) and the H3K27me3 peak dataset (GSM1498900). Specifically, we define classical EZH2 sites as those where EZH2 co-localizes with H3K27me3. Conversely, non-classical EZH2 sites are identified where EZH2 does not co-localize with H3K27me3.
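
The labelling rule above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names `overlaps` and `label_ezh2_site` are hypothetical, intervals are half-open `(chrom, start, end)` tuples, and "co-localizes" is taken to mean an overlap of at least one base pair. The tutorial itself applies the rule on 1-kb bins rather than on raw peaks.

```python
def overlaps(a, b):
    """Return True if two half-open (chrom, start, end) intervals share at least 1 bp."""
    return a[0] == b[0] and a[1] < b[2] and b[1] < a[2]


def label_ezh2_site(ezh2_site, h3k27me3_peaks):
    """Label an EZH2 site 'classical' if any H3K27me3 peak overlaps it."""
    if any(overlaps(ezh2_site, peak) for peak in h3k27me3_peaks):
        return "classical"
    return "non-classical"


# Toy coordinates, invented purely for illustration (not taken from the datasets).
h3k27me3 = [("chr1", 1000, 3000), ("chr2", 500, 900)]
print(label_ezh2_site(("chr1", 2500, 2800), h3k27me3))  # classical
print(label_ezh2_site(("chr1", 5000, 5600), h3k27me3))  # non-classical
```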

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # select which GPU to use

import sys 
import torch 
import numpy as np 
import pandas as pd 
from tqdm import tqdm 
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary

import lightning.pytorch as pl
basedir = "/home/yangdongxu/repos/ChromBERT_reorder/data"  





In [2]:
peak_ezh2 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1003524_EZH2.bed")
!head -3 {peak_ezh2}

chr3	93470270	93470880	peak7668	2339	.	23.98034	243.39191	233.90086	232
chr2	91477646	91478694	peak6148	1127	.	10.28381	119.74835	112.79007	430
chr16	46390276	46390857	peak4186	1039	.	8.55891	110.78978	103.94416	350


In [3]:
peak_h3k27me3 = os.path.join(basedir, "demo", "ezh2", "hESC_GSM1498900_H3K27me3.bed")
!head -3 {peak_h3k27me3}

chr10	100480580	100483730	peak1145	6.37280
chr10	100519337	100521146	peak1146	6.86372
chr10	100655029	100656371	peak1147	15.71588


In [4]:
# we first align peaks to our predefined bins
from chrombert.scripts.chrombert_make_dataset import get_overlap
ref_regions = os.path.join(basedir, "config", "hg38_6k_1kb_region.bed")
df1 = get_overlap(
    supervised = peak_ezh2, 
    regions = ref_regions,
    no_filter = False,
).assign(label = lambda df: df["label"] > 0 )
df2 = get_overlap(
    supervised = peak_h3k27me3, 
    regions = ref_regions,
    no_filter = True,
).assign(label = lambda df: df["label"] > 0 )
df1.head(), df2.head()

(  chrom   start     end  build_region_index  label
 0  chr1  870000  871000                 174   True
 1  chr1  905000  906000                 204   True
 2  chr1  923000  924000                 220   True
 3  chr1  924000  925000                 221   True
 4  chr1  925000  926000                 222   True,
   chrom  start    end  build_region_index  label
 0  chr1  10000  11000                   0  False
 1  chr1  16000  17000                   1  False
 2  chr1  17000  18000                   2  False
 3  chr1  29000  30000                   3  False
 4  chr1  30000  31000                   4  False)
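
To make the bin alignment more concrete, here is a minimal sketch of which 1-kb bins a peak touches. It assumes a regular grid of half-open 1-kb bins; `bins_for_peak` is a hypothetical helper, not part of `chrombert`, and the internals of `get_overlap` are not shown in this tutorial. Note also that the reference file appears to keep only a subset of grid bins (the output above starts at 10000, 16000, 17000), so a real alignment would additionally discard bins missing from that file.

```python
BIN_SIZE = 1000  # assumed bin width in bp, matching the "1kb" in the reference file name


def bins_for_peak(start, end, bin_size=BIN_SIZE):
    """List the (bin_start, bin_end) grid bins touched by a half-open peak [start, end)."""
    first = start // bin_size
    last = (end - 1) // bin_size
    return [(i * bin_size, (i + 1) * bin_size) for i in range(first, last + 1)]


# The first two EZH2 peaks printed earlier in this tutorial.
print(bins_for_peak(93470270, 93470880))  # [(93470000, 93471000)]
print(bins_for_peak(91477646, 91478694))  # [(91477000, 91478000), (91478000, 91479000)]
```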

In [5]:
# df1 keeps only bins with an EZH2 peak; after the merge, `label` is True where H3K27me3 is also present (classical site)
df_supervised = df1.rename(columns = {"label": "EZH2"}).merge(df2)
df_supervised

|  | chrom | start | end | build_region_index | EZH2 | label |
|---|---|---|---|---|---|---|
| 0 | chr1 | 870000 | 871000 | 174 | True | True |
| 1 | chr1 | 905000 | 906000 | 204 | True | True |
| 2 | chr1 | 923000 | 924000 | 220 | True | True |
| 3 | chr1 | 924000 | 925000 | 221 | True | False |
| 4 | chr1 | 925000 | 926000 | 222 | True | True |
| ... | ... | ... | ... | ... | ... | ... |
| 11003 | chrX | 154750000 | 154751000 | 2134828 | True | True |
| 11004 | chrY | 5001000 | 5002000 | 2135730 | True | False |
| 11005 | chrY | 10994000 | 10995000 | 2136403 | True | False |
| 11006 | chrY | 26670000 | 26671000 | 2137888 | True | False |


In [6]:
df_supervised.groupby("label").size() # that's a near balanced dataset

label
False    5272
True     5736
dtype: int64

In [7]:
# we then split the dataset into training, validation and test sets
from sklearn.model_selection import train_test_split
df_train, df_temp = train_test_split(df_supervised, test_size=0.2, random_state=42, stratify = df_supervised['label'])
df_valid, df_test = train_test_split(df_temp, test_size=0.5, random_state=42, stratify = df_temp['label'])

os.makedirs("tmp_ezh2", exist_ok=True)
df_train.to_csv(os.path.join("tmp_ezh2", "train.csv"))
df_valid.to_csv(os.path.join("tmp_ezh2", "valid.csv"))
df_test.to_csv(os.path.join("tmp_ezh2", "test.csv"))

len(df_train), len(df_valid), len(df_test)

(8806, 1101, 1101)
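
The reported split sizes can be checked with a little arithmetic. The sketch below assumes that the held-out fraction is rounded up and that the remainder goes to the other part, which matches the sizes printed above; `split_sizes` is a hypothetical helper written for this check, not a scikit-learn function.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_kept, n_held_out), rounding the held-out part up."""
    n_held_out = math.ceil(test_size * n_samples)
    return n_samples - n_held_out, n_held_out


n_total = 5272 + 5736  # class counts reported by the groupby above
n_train, n_temp = split_sizes(n_total, 0.2)   # 80% train, 20% held out
n_valid, n_test = split_sizes(n_temp, 0.5)    # held-out part halved
print(n_total, n_train, n_valid, n_test)  # 11008 8806 1101 1101
```

The result is an 80/10/10 split into training, validation and test sets.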

## Fine-tune
In this section, we show how to fine-tune ChromBERT for our specific task. The steps are similar to the standard ChromBERT workflow, with a few key modifications:

- Dataset preparation: We use the `ignore_object` parameter to exclude H3K27me3-related cistromes from the original ChromBERT dataset, so that information about H3K27me3 does not leak into the input.
- Model instantiation: We pass a special `ignore_index` parameter, derived from the dataset, to instantiate the model correctly.

Let's get started!

### Instructing the dataset to ignore a given factor

In [8]:
dc = chrombert.get_preset_dataset_config(
    "general", 
    supervised_file = None, 
    ignore = False, ignore_object = "h3k27me3" # turn off ignore
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
ds[1]["input_ids"].shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


torch.Size([6391])

In [9]:
dc = chrombert.get_preset_dataset_config(
    "general", 
    supervised_file = None, 
    ignore = True, ignore_object = "h3k27me3" # we ignore the h3k27me3 related cistrome, to avoid data leakage
    )
ds = dc.init_dataset(supervised_file = os.path.join("tmp_ezh2", "train.csv"))
ignore_index = ds[0]["ignore_index"]  # needed to build the model; a dataset currently supports a single ignore object, so any sample yields the same ignore_index
ds[1]["input_ids"].shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json


torch.Size([6185])

As you can see above, the dataset, after ignoring the specified cistromes, works as normal. However, the length of the input sequence is 6185 instead of 6391. A total of 206 H3K27me3-related cistromes are ignored and do not participate in the training process.
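
Conceptually, ignoring cistromes amounts to dropping the corresponding positions from the input sequence. The sketch below illustrates that bookkeeping on a toy sequence. It assumes that the ignored cistromes can be described as a list of positions; `drop_ignored` is a hypothetical helper, and the actual format of `ignore_index` and the way ChromBERT applies it are not shown in this tutorial.

```python
def drop_ignored(tokens, ignore_index):
    """Return tokens with the positions listed in ignore_index removed."""
    ignored = set(ignore_index)
    return [tok for pos, tok in enumerate(tokens) if pos not in ignored]


# Toy sequence: 8 cistrome tokens, of which positions 2 and 5 are to be ignored.
tokens = [3, 1, 4, 1, 5, 9, 2, 6]
print(drop_ignored(tokens, ignore_index=[2, 5]))  # [3, 1, 1, 5, 2, 6]

# The same bookkeeping with the lengths reported above.
full_length, reduced_length = 6391, 6185
print(full_length - reduced_length)  # 206
```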

Additionally, a small note: we tune the model using PyTorch Lightning, so the dataset is wrapped in the `lightning.pytorch.LightningDataModule` class.



In [10]:
data_module = chrombert.LitChromBERTFTDataModule(
    config = dc,
    train_params = dict(supervised_file = os.path.join("tmp_ezh2", "train.csv")),
    val_params = dict(supervised_file = os.path.join("tmp_ezh2", "valid.csv")),
    test_params = dict(supervised_file = os.path.join("tmp_ezh2", "test.csv")),
)
data_module

<chrombert.finetune.dataset.data_module.LitChromBERTFTDataModule at 0x7f79676db460>

### Instantiate the Model
Next, we can instantiate the model using the ignore_index parameter.

In [11]:
model = chrombert.get_preset_model_config(
    "general", 
    ignore = True, ignore_index =ignore_index  # ignore_index from above
).init_model()
model.freeze_pretrain(trainable=2) # we only fine-tune the last two layers
summary(model, depth=2)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
use organisim hg38; max sequence length including cls is 6392


Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               (4,916,736)
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,230,720
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
Total params: 62,753,282
Trainable params: 18,852,866
Non-trainable params: 43,900,416
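
The trainable and frozen parameter counts in this summary can be reproduced by hand. The sketch below assumes that the encoder stack consists of 8 equally sized blocks, which is consistent with the embedding-manager summary shown later in this tutorial, and that `freeze_pretrain(trainable=2)` leaves the task head and the last two blocks trainable. All numbers are copied from the output above.

```python
# Parameter counts copied from the summary output in this tutorial.
embedding = 4_916_736
encoder_stack = 51_978_240
head = 769 + 3_230_720 + 2_166_528 + 460_032 + 257  # GeneralHeader layers

n_blocks = 8  # assumption: the encoder stack holds 8 equally sized blocks
per_block = encoder_stack // n_blocks

trainable = head + 2 * per_block                 # head plus the last two blocks
frozen = embedding + (n_blocks - 2) * per_block  # embedding plus the first six blocks

print(per_block)           # 6497280
print(trainable)           # 18852866
print(frozen)              # 43900416
print(trainable + frozen)  # 62753282
```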

### Fine-tune the model
We fine-tune the model using PyTorch Lightning. A simple configuration object holds the training parameters, and tuning runs for a single epoch on a limited dataset to save time.

Note: Fine-tuning is stochastic, so results may vary between runs. For better results, consider increasing the number of epochs and the size of the dataset.



In [12]:
tc = chrombert.finetune.train.TrainConfig(
    kind = "classification",
    loss = "bce", # use bce loss because it's a balanced binary classification task
    max_epochs = 1,
    lr = 1e-4
)
pl_module = tc.init_pl_module(model) # wrap model with PyTorch Lightning module
type(pl_module)

chrombert.finetune.train.pl_module.ClassificationPLModule
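
For readers unfamiliar with the loss, the sketch below computes a mean binary cross-entropy in plain Python. It assumes that `loss = "bce"` refers to the standard binary cross-entropy; whether ChromBERT applies it to logits or to probabilities is not shown in this tutorial, and the `bce` function here is a hypothetical stand-in rather than the library's implementation.

```python
import math


def bce(probabilities, labels, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(probabilities, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


# An uninformative classifier (p = 0.5 everywhere) scores ln(2), about 0.693.
print(round(bce([0.5, 0.5], [1, 0]), 3))  # 0.693
# Confident, correct predictions score close to 0.
print(bce([0.99, 0.01], [1, 0]) < 0.02)   # True
```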

Then we start tuning!   
The trainer will save logs in a format supported by TensorBoard, and several checkpoints may be saved during the process.  
However, in this tutorial, we use the latest model parameters instead of checkpoints due to limited tuning.



In [13]:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # for tensorboard compatibility
callback_ckpt = pl.callbacks.ModelCheckpoint( monitor = f"{tc.tag}_validation/{tc.loss}", mode = "min")
trainer = pl.Trainer(
    max_epochs = tc.max_epochs,
    accelerator = "gpu",
    precision = "bf16-mixed",
    fast_dev_run = False,
    accumulate_grad_batches = 16, 
    logger = pl.loggers.TensorBoardLogger(os.path.join("tmp_ezh2","logs"), name = "ezh2"),
    val_check_interval = 128,
    limit_val_batches = 128,
    log_every_n_steps = 1,
    callbacks = [ callback_ckpt, pl.callbacks.LearningRateMonitor() ],
)
trainer.fit(pl_module, data_module)
pl_module.save_ckpt(os.path.join("tmp_ezh2", "ezh2.ckpt"))

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type             | Params
-------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M
-------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.013   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


`Trainer.fit` stopped: `max_epochs=1` reached.


## Use the tuned model to get regulator embeddings

ChromBERT is now fine-tuned! You can access the tuned model directly through `pl_module.model`. Note, however, that because of specific settings in flash-attention, the dropout probability of that model cannot be changed, which may introduce some randomness in its output.

For consistent results, we recommend saving the checkpoint and loading it into a newly instantiated model with `dropout = 0`, as in the next cell.

In [14]:
model_tuned = chrombert.get_preset_model_config(
    "general", 
    dropout = 0,
    ignore = True, ignore_index =ignore_index,   # ignore_index from above
    finetune_ckpt = os.path.abspath(os.path.join("tmp_ezh2", "ezh2.ckpt")) # use absolute path here, to avoid mixing of preset
).init_model()
# or use model_tuned = pl_module.model
summary(model_tuned, depth = 2)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
update path: finetune_ckpt = /home/yangdongxu/repos/ChromBERT_reorder/examples/tutorials/tmp_ezh2/ezh2.ckpt
use organisim hg38; max sequence length including cls is 6392
Loading checkpoint from /home/yangdongxu/repos/ChromBERT_reorder/examples/tutorials/tmp_ezh2/ezh2.ckpt
Loaded 110/110 parameters


Layer (type:depth-idx)                                  Param #
ChromBERTGeneral                                        --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               4,916,736
│    └─ModuleList: 2-2                                  51,978,240
├─GeneralHeader: 1-2                                    --
│    └─CistromeEmbeddingManager: 2-3                    --
│    └─Conv2d: 2-4                                      769
│    └─ReLU: 2-5                                        --
│    └─ResidualBlock: 2-6                               3,230,720
│    └─ResidualBlock: 2-7                               2,166,528
│    └─ResidualBlock: 2-8                               460,032
│    └─Linear: 2-9                                      257
Total params: 62,753,282
Trainable params: 62,753,282
Non-trainable params: 0

Next, we obtain the embedding manager, following the tutorial on extracting embeddings.

In [15]:
model_emb = model_tuned.get_embedding_manager().cuda()
summary(model_emb)

Layer (type:depth-idx)                                  Param #
ChromBERTEmbedding                                      --
├─ChromBERT: 1-1                                        --
│    └─BERTEmbedding: 2-1                               --
│    │    └─TokenEmbedding: 3-1                         7,680
│    │    └─PositionalEmbedding: 3-2                    4,909,056
│    │    └─Dropout: 3-3                                --
│    └─ModuleList: 2-2                                  --
│    │    └─EncoderTransformerBlock: 3-4                6,497,280
│    │    └─EncoderTransformerBlock: 3-5                6,497,280
│    │    └─EncoderTransformerBlock: 3-6                6,497,280
│    │    └─EncoderTransformerBlock: 3-7                6,497,280
│    │    └─EncoderTransformerBlock: 3-8                6,497,280
│    │    └─EncoderTransformerBlock: 3-9                6,497,280
│    │    └─EncoderTransformerBlock: 3-10               6,497,280
│    │    └─EncoderTransformerBlock: 3-11          

In [16]:
dc_test = data_module.test_config
ds_test = dc_test.init_dataset()
dl_test = dc_test.init_dataloader(batch_size = 1)
len(ds_test), list(ds_test[0].keys())

(1101,
 ['input_ids',
  'position_ids',
  'region',
  'build_region_index',
  'ignore_index',
  'label'])

In [17]:
embs_classicial = []
embs_nonclassicial = []
for batch in tqdm(dl_test):
    with torch.no_grad():
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        emb = model_emb(batch)
    if batch["label"].item() == 1:
        embs_classicial.append(emb)
    else:
        embs_nonclassicial.append(emb)

print(len(embs_classicial), len(embs_nonclassicial))
embs_classicial = torch.cat(embs_classicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_nonclassicial = torch.cat(embs_nonclassicial, dim = 0).cpu().numpy().mean(axis = 0)
embs_classicial.shape, embs_nonclassicial.shape

100%|██████████| 1101/1101 [01:06<00:00, 16.56it/s]


573 528


((1063, 768), (1063, 768))

We keep only the factors listed below, removing histone modifications and chromatin accessibility. We can then compute the similarity between factor embeddings, which represents potential interactions between factors.
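
As a reminder of what the similarity measures, here is a dependency-free sketch of cosine similarity between two embedding vectors. The tutorial itself uses `sklearn.metrics.pairwise.cosine_similarity` on the full matrix; the `cosine` helper and the toy vectors below are invented for illustration only.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy 3-dimensional "embeddings", invented for illustration.
emb = {
    "factor_a": [1.0, 0.0, 1.0],
    "factor_b": [2.0, 0.0, 2.0],  # same direction as factor_a
    "factor_c": [0.0, 1.0, 0.0],  # orthogonal to factor_a
}
print(round(cosine(emb["factor_a"], emb["factor_b"]), 6))  # 1.0
print(round(cosine(emb["factor_a"], emb["factor_c"]), 6))  # 0.0
```

Because cosine similarity ignores vector length, two factors score highly when their embeddings point in the same direction, regardless of magnitude.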


In [18]:
with open(os.path.join(basedir, "config","hg38_6k_factors_list.txt"),"r") as f:
    factors = f.read().strip().split("\n")
factors = [f.strip().lower() for f in factors]
factors[:3], len(factors)

(['adnp', 'aebp2', 'aff1'], 992)

In [19]:
indices = np.isin(model_emb.list_regulator, factors)  # boolean mask of regulators kept as factors (np.in1d is deprecated in recent NumPy)
names = np.array(model_emb.list_regulator)[indices]
embs_classicial = embs_classicial[indices]
embs_nonclassicial = embs_nonclassicial[indices]


In [20]:
from sklearn.metrics.pairwise import cosine_similarity
cos_classicial_matrix = cosine_similarity(embs_classicial)
cos_nonclassicial_matrix = cosine_similarity(embs_nonclassicial)
df_cos_classicial = pd.DataFrame(cos_classicial_matrix, columns = names, index = names)
df_cos_nonclassicial = pd.DataFrame(cos_nonclassicial_matrix, columns = names, index = names)

In [21]:
# we define a threshold to select the most strongly related regulator pairs
thre_class = np.percentile(cos_classicial_matrix.flatten(), 95)
thre_nonclass = np.percentile(cos_nonclassicial_matrix.flatten(), 95)
thre_class, thre_nonclass

(0.6016223430633545, 0.5621198415756226)
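
The thresholding step can be illustrated on a handful of numbers. The sketch below implements a percentile with linear interpolation between neighbouring sorted values, which is the default behaviour of `np.percentile`; the `percentile` helper and the toy scores are assumptions made for illustration. Note that flattening the full similarity matrix, as done above, counts every pair twice and includes the diagonal of self-similarities.

```python
import math


def percentile(values, q):
    """q-th percentile with linear interpolation between neighbouring order statistics."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100.0
    lower = math.floor(position)
    upper = math.ceil(position)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


# Toy similarity scores, invented for illustration.
scores = [0.1, 0.2, 0.3, 0.4, 0.5]
threshold = percentile(scores, 95)
print(round(threshold, 2))                   # 0.48
print([s for s in scores if s > threshold])  # [0.5]
```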

Now we look for TRNs related to the non-classical functions of EZH2. As shown below, the factors associated with EZH2 classical functions are Polycomb-complex related, e.g. SUZ12, whereas the factors associated with EZH2 non-classical functions tend to be linked to transcriptional activation, e.g. EP300 and TP53.

In [22]:
df_cos_ezh2 = pd.DataFrame(index =names, data = {"classical":df_cos_classicial.loc["ezh2",:],"nonclassical":df_cos_nonclassicial.loc["ezh2",:]})
df_cos_ezh2["diff"] = df_cos_ezh2["classical"] - df_cos_ezh2["nonclassical"]
df_cos_ezh2

|  | classical | nonclassical | diff |
|---|---|---|---|
| adnp | 0.270390 | 0.388505 | -0.118115 |
| aebp2 | 0.232409 | 0.334547 | -0.102139 |
| aff1 | 0.447353 | 0.485739 | -0.038386 |
| aff4 | 0.403353 | 0.492727 | -0.089375 |
| ago1 | 0.282023 | 0.288040 | -0.006017 |
| ... | ... | ... | ... |
| zscan5a | 0.268473 | 0.325382 | -0.056909 |
| zta | 0.257478 | 0.291565 | -0.034087 |
| zxdb | 0.250267 | 0.270615 | -0.020347 |
| zxdc | 0.218339 | 0.287684 | -0.069345 |


In [23]:
df_cos_ezh2.query("classical > @thre_class ").sort_values("diff", ascending = False).head(10)

|  | classical | nonclassical | diff |
|---|---|---|---|
| ezh1 | 0.627185 | 0.576553 | 0.05063266 |
| suz12 | 0.877764 | 0.832076 | 0.04568756 |
| bcor | 0.602734 | 0.56039 | 0.04234409 |
| jarid2 | 0.61212 | 0.588507 | 0.02361345 |
| ezh2 | 1.0 | 1.0 | -2.384186e-07 |
| rnf2 | 0.704727 | 0.709211 | -0.004483521 |


In [24]:
df_cos_ezh2.query("nonclassical > @thre_nonclass ").sort_values("diff", ascending = True).head(10)

|  | classical | nonclassical | diff |
|---|---|---|---|
| foxm1 | 0.450502 | 0.606847 | -0.156345 |
| hinfp | 0.446726 | 0.583847 | -0.137121 |
| med1 | 0.440611 | 0.574923 | -0.134312 |
| brca1 | 0.465605 | 0.587078 | -0.121473 |
| rela | 0.447773 | 0.568249 | -0.120476 |
| stat3 | 0.503793 | 0.62012 | -0.116327 |
| ep300 | 0.477905 | 0.591606 | -0.113701 |
| tp53 | 0.465553 | 0.57892 | -0.113366 |
| e2f4 | 0.450022 | 0.562298 | -0.112277 |
| stat1 | 0.469702 | 0.570898 | -0.101195 |
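
The filter-and-rank logic of the last two cells can be retraced without pandas. The sketch below uses a few rows copied from the tables in this tutorial together with the non-classical threshold reported earlier; the dictionary layout and variable names are chosen for illustration only.

```python
# Values copied from the tables in this tutorial: factor -> (classical, nonclassical).
similarity_to_ezh2 = {
    "adnp": (0.270390, 0.388505),
    "ezh1": (0.627185, 0.576553),
    "suz12": (0.877764, 0.832076),
    "stat3": (0.503793, 0.620120),
    "foxm1": (0.450502, 0.606847),
}
thre_nonclass = 0.5621198415756226  # 95th-percentile threshold reported earlier

# Keep factors above the non-classical threshold, then rank by classical - nonclassical.
selected = [
    (name, classical - nonclassical)
    for name, (classical, nonclassical) in similarity_to_ezh2.items()
    if nonclassical > thre_nonclass
]
selected.sort(key=lambda item: item[1])  # most negative diff first

print([name for name, _ in selected])  # ['foxm1', 'stat3', 'suz12', 'ezh1']
```

Factors at the top of this ranking are more similar to EZH2 at non-classical sites than at classical sites, while Polycomb members such as SUZ12 pass the threshold but sit at the other end.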


## The end
This tutorial provides a comprehensive guide to locus-specific TRN inference, using EZH2 co-association as an example.   
We hope you find it helpful and informative.   
If you have any questions or need further assistance, please feel free to ask. Thank you for reading!