# Example for causal eQTL identification using prompt-enhanced ChromBERT
ChromBERT’s integration of a DNA sequence prompt from DNABERT-2 enhances its versatility for genomic applications, such as fine-mapping eQTLs.   
By incorporating DNA sequence variation, ChromBERT can identify causal variants influencing gene expression.   
Using the latest eQTL Catalogue, we fine-tuned ChromBERT to classify causal and non-causal variants as an example of its capability.

**Attention: You should go through this [tutorial](https://chrombert.readthedocs.io/en/latest/tutorial_finetuning_ChromBERT.html) at first to get familiar with the basic usage of ChromBERT.**

In [1]:
import  os 
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # to selected gpu used 

import sys 
import pathlib 
import pickle

import torch 
import numpy as np 
import pandas as pd 
from tqdm import tqdm 
from matplotlib import pyplot as plt
import seaborn as sns
import chrombert
from torchinfo import summary

import lightning.pytorch as pl

import sklearn 
from sklearn import metrics

basedir = os.path.expanduser("~/.cache/chrombert/data"  )


  from pandas.core import (


## Prepare datasets

We provide a demo eQTL dataset for the lung, which includes additional columns compared to those used in other tasks: `base_ref`, `base_alt`, `variant_id`, `label`, and `pos`.

- `pos`: Specifies the variant position.  
- `base_ref`: Indicates the reference allele.  
- `base_alt`: Represents the alternative allele.  

The `variant_id` serves as a unique identifier for each variant and can be any unique string for convenience. The `label` column classifies variants as causal (1) or non-causal (0). If the dataset is used only for prediction, the `label` column can be omitted.

In [2]:
table_eqtls = os.path.join(basedir, "demo","eqtl", "lung_eqtl.tsv")
!head $table_eqtls

chrom	start	end	build_region_index	base_ref	base_alt	variant_id	label	pos
chr17	29595000	29596000	771566	A	G	chr17_29595257_A_G	1	29595257
chr11	78139000	78140000	342669	C	T	chr11_78139778_C_T	1	78139778
chr19	35309000	35310000	898784	G	T	chr19_35309759_G_T	1	35309759
chr4	6695000	6696000	1359650	T	C	chr4_6695946_T_C	1	6695946
chr6	139029000	139030000	1719458	G	A	chr6_139029045_G_A	1	139029045
chr15	41554000	41555000	640548	T	G	chr15_41554978_T_G	1	41554978
chr16	4343000	4344000	694493	A	C	chr16_4343375_A_C	1	4343375
chr9	97238000	97239000	2025247	A	G	chr9_97238449_A_G	1	97238449
chr5	96757000	96758000	1551457	G	A	chr5_96757518_G_A	1	96757518


In [3]:
from sklearn.model_selection import train_test_split

odir = pathlib.Path("tmp_eqtl")
odir.mkdir(exist_ok=True, parents=True)

df_full = pd.read_csv(table_eqtls, sep="\t")
df_train, df_test = train_test_split(df_full, test_size=0.4, random_state=42)
df_train.to_csv(odir / "train.tsv", sep="\t", index=False)
df_test.to_csv(odir / "test.tsv", sep="\t", index=False)
len(df_train), len(df_test)

(921, 614)

In [4]:
# configure dataset
dc = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = table_eqtls)
print(dc)

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
{
    "hdf5_file": "/home/yangdongxu/.cache/chrombert/data/hg38_6k_1kb.hdf5",
    "supervised_file": "/home/yangdongxu/.cache/chrombert/data/demo/eqtl/lung_eqtl.tsv",
    "kind": "PromptDataset",
    "meta_file": "/home/yangdongxu/.cache/chrombert/data/config/hg38_6k_meta.json",
    "ignore": false,
    "ignore_object": null,
    "batch_size": 8,
    "num_workers": 20,
    "shuffle": false,
    "perturbation": false,
    "perturbation_object": null,
    "perturbation_value": 0,
    "prompt_kind": "dna",
    "prompt_regulator": null,
    "prompt_regulator_cache_file": null,
    "prompt_celltype": null,
    "prompt_celltype_cache_file": null,
    "fasta_file": "/home/yangdongxu/.cache/chrombert/data/other/hg38.fa",
    "flank_window": 0
}


In [5]:
# initialize dataset
ds = dc.init_dataset()
ds[1]

{'input_ids': tensor([9, 9, 9,  ..., 6, 6, 6], dtype=torch.int8),
 'position_ids': tensor([   1,    2,    3,  ..., 6389, 6390, 6391]),
 'region': tensor([      11, 78139000, 78140000], dtype=torch.int32),
 'build_region_index': 342669,
 'label': 1,
 'seq_raw': 'GATTTGTTCCAAATCAGACAGCGCCAGGTCTGAACCTAGCCAGCTGGGGCTAAGTCAAGTAACAACTGGCGAAACAGAAAGCTTAGCAAAGGCAGGATAGCGACAAACACGACCTAAAGTTTTCTCTTCATACCCAGGGATATCCACACCTTTCTCTCCCGCCCTGACCGACCGCGGGGCCTCCCCGCCCAGCCCCTGGCCGTGCGAGTCCCTTACTATGTGGGGATGAGAAGGCATTTGAGAAGAGTCACCCCGAGCGCCAAAGCCGAAAACCAATTGCCAGTACCCGTGGCAATTGTGAGCGCCGCCATTGCTGCGGCACCGCACGCTTCCCACCAACTTGATCCACATCCGGGATCCCGCGCATGCGGAGAAAGCCCTCTGAAGCCGTGCCCGCTAGCTGCGCGCATGCGGCGAGCGGCGCAGCCAGTCCGGGGACTGCAGTCAGCTATTTAAACCTCCCGCCCACCTTTTCTTTAGACCCGCGTCTCACCCCGGGCCGGAAGGGCTCCTGCGCAGGCGTTTGTAGCCACTTTTAAGTTTTATCAGCTAGTTCATGCTTGCGTTGAAAGAGTGGTCGTTTGCGCTGGGTCATCACTGTGTAGTATTGGGGATACTTAGGTGAGAAAAAAACTTAACGCTAGAGACGTTCACGCACTAGTGGAGAAGCCAGGATTGTTGCCCTAGAGTTACAGTAGATAAAAGTACCTCAGAGAACTGCGGGGGCTCCCAACCTGG

## Prepare the model

The model can be loaded in the same way as for other tasks, except that the kind is set to `prompt_dna`.

In [6]:
mc = chrombert.get_preset_model_config(
    "prompt_dna",
    dnabert2_ckpt="zhihan1996/DNABERT-2-117M" # use model from hugging-face, or provide path directly
    )
print(mc)

update path: mtx_mask = config/hg38_6k_mask_matrix.tsv
update path: pretrain_ckpt = checkpoint/hg38_6k_1kb_pretrain.ckpt
{
    "genome": "hg38",
    "task": "prompt",
    "dim_output": 1,
    "mtx_mask": "/home/yangdongxu/.cache/chrombert/data/config/hg38_6k_mask_matrix.tsv",
    "dropout": 0.1,
    "pretrain_ckpt": "/home/yangdongxu/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt",
    "finetune_ckpt": null,
    "ignore": false,
    "ignore_index": [
        null,
        null
    ],
    "gep_flank_window": 4,
    "gep_parallel_embedding": false,
    "gep_gradient_checkpoint": false,
    "gep_zero_inflation": true,
    "prompt_kind": "dna",
    "prompt_dim_external": 768,
    "dnabert2_ckpt": "zhihan1996/DNABERT-2-117M"
}


In [7]:
model = mc.init_model()
summary(model)

use organisim hg38; max sequence length including cls is 6392


Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to b

Layer (type:depth-idx)                                            Param #
ChromBERTPromptDNA                                                --
├─ChromBERT: 1-1                                                  --
│    └─BERTEmbedding: 2-1                                         --
│    │    └─TokenEmbedding: 3-1                                   7,680
│    │    └─PositionalEmbedding: 3-2                              4,909,056
│    │    └─Dropout: 3-3                                          --
│    └─ModuleList: 2-2                                            --
│    │    └─EncoderTransformerBlock: 3-4                          6,497,280
│    │    └─EncoderTransformerBlock: 3-5                          6,497,280
│    │    └─EncoderTransformerBlock: 3-6                          6,497,280
│    │    └─EncoderTransformerBlock: 3-7                          6,497,280
│    │    └─EncoderTransformerBlock: 3-8                          6,497,280
│    │    └─EncoderTransformerBlock: 3-9             

In [8]:
# we freeze DNABERT-2 model during traing 
model.dnabert2.freeze()
model.display_trainable_parameters()

{'total_params': 191152515, 'trainable_params': 74083971}
pretrain_model.embedding.token.weight : trainable
pretrain_model.embedding.position.pe.pe.weight : trainable
pretrain_model.transformer_blocks.0.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.0.attention.Wqkv.bias : trainable
pretrain_model.transformer_blocks.0.feed_forward.w_1.weight : trainable
pretrain_model.transformer_blocks.0.feed_forward.w_1.bias : trainable
pretrain_model.transformer_blocks.0.feed_forward.w_2.weight : trainable
pretrain_model.transformer_blocks.0.feed_forward.w_2.bias : trainable
pretrain_model.transformer_blocks.0.input_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.0.input_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.0.output_sublayer.norm.a_2 : trainable
pretrain_model.transformer_blocks.0.output_sublayer.norm.b_2 : trainable
pretrain_model.transformer_blocks.1.attention.Wqkv.weight : trainable
pretrain_model.transformer_blocks.1.attention.Wqkv.

{'total_params': 191152515, 'trainable_params': 74083971}

## Fine-tune

This task has fewer training samples compared to others, so we use a reduced number of training steps. To simplify the process, the model is trained directly without using PyTorch Lightning.

In [9]:
dc_train = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = df_train)
dc_test = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = df_test)

ds_train = dc_train.init_dataset()
ds_test = dc_test.init_dataset()
dl_train = dc_train.init_dataloader(batch_size=2, shuffle=True, num_workers=4)
dl_test = dc_test.init_dataloader(batch_size=2, shuffle=False, num_workers=4)


update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa
update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa


In [10]:
from transformers import get_linear_schedule_with_warmup 
from torch import nn
 
def train(m, dl, lr = 5e-5, grad_accumulation_steps=4, min_epochs = 2, max_epochs = 5, max_steps = 200):
    num_training_steps = len(dl) * max_epochs // grad_accumulation_steps
    optimizer = torch.optim.AdamW(m.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
    loss_f = nn.BCEWithLogitsLoss()

    keys_to_cuda = ["input_ids", "position_ids", "label"]
    total_steps = 0
    for e in range(max_epochs):
        m.train()
        for i, batch in enumerate(tqdm(dl)):
            # batch = {k: v.cuda() for k, v in batch.items()}
            for k in keys_to_cuda:
                batch[k] = batch[k].cuda()

            logits = m(batch).view(-1)
            loss = loss_f(logits, batch["label"].to(torch.float))
            loss = loss.mean() / grad_accumulation_steps
            loss.backward()
            if (i+1) % grad_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                total_steps += 1
                if total_steps > max_steps:
                    if e >= min_epochs:
                        return m

        print(f"epoch {e} loss {loss.item()}")
    return m

In [11]:
model_tuned = train(model.cuda().bfloat16(), dl_train)

  0%|          | 0/461 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 461/461 [01:18<00:00,  5.89it/s]


epoch 0 loss 0.032958984375


100%|██████████| 461/461 [01:16<00:00,  6.00it/s]


epoch 1 loss 0.0133056640625


  1%|          | 3/461 [00:01<03:13,  2.37it/s]


## Evaluation  

Typically, we save the checkpoint and evaluate it later. However, direct evaluation is an option if minimal randomness is acceptable. 

In [12]:
model_tuned.save_ckpt(odir / "model.ckpt")


In [13]:
model = mc.init_model(finetune_ckpt = odir / "model.ckpt", dropout=0).cuda().bfloat16()
model.display_trainable_parameters(verbose=False)

use organisim hg38; max sequence length including cls is 6392


Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to b

Loading checkpoint from tmp_eqtl/model.ckpt
Loaded 290/290 parameters
{'total_params': 191152515, 'trainable_params': 191152515}


{'total_params': 191152515, 'trainable_params': 191152515}

In [14]:
model = model.cuda().bfloat16()
probs = []
labels = []
model.eval()
with torch.no_grad():
    for i, batch in enumerate(tqdm(dl_test)):
        keys_to_cuda = ["input_ids", "position_ids", "label"]
        for k in keys_to_cuda:
            batch[k] = batch[k].cuda()
        logits = model(batch).view(-1)
        probs.append(logits.sigmoid().float().cpu().numpy())
        labels.append(batch["label"].cpu().numpy())

probs = np.concatenate(probs)
labels = np.concatenate(labels)
auc = metrics.roc_auc_score(labels, probs)
aupr = metrics.average_precision_score(labels, probs)
print(f"test auc {auc}, aupr {aupr}")

  0%|          | 0/307 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 307/307 [00:19<00:00, 15.54it/s]

test auc 0.8792105598504799, aupr 0.8742700445408895





## Identification of differential regulators

We can infer differential regulators for causal and non-causal variants by comparing distances in regulator embeddings. Notably, DNase-seq data often shows a significant difference in these distances.


In [15]:
model_emb = model.get_embedding_manager()
model_emb

ChromBERTEmbedding(
  (pretrain_model): ChromBERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(10, 768, padding_idx=0)
      (position): PositionalEmbedding(
        (pe): PositionalEmbeddingTrainable(
          (pe): Embedding(6392, 768)
        )
      )
      (dropout): Dropout(p=0, inplace=False)
    )
    (transformer_blocks): ModuleList(
      (0-7): 8 x EncoderTransformerBlock(
        (attention): SelfAttentionFlashMHA(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0, inplace=False)
          (activation): GELU()
        )
        (input_sublayer): SublayerConnection(
          (norm): LayerNorm()
          (dropout): Dropout(p=0, inplace=False)
        )
        (output_sublayer): SublayerConnection(
     

In [16]:
dc = chrombert.get_preset_dataset_config("prompt_dna", supervised_file = table_eqtls)
dl = dc.init_dataloader(batch_size = 2, shuffle=False, num_workers=4)
labels = []
embeddings = []
for batch in tqdm(dl):
    keys_to_cuda = ["input_ids", "position_ids", "label"]
    for k in keys_to_cuda:
        batch[k] = batch[k].cuda()
    labels.append(batch["label"].cpu().numpy())
    embeddings.append(model_emb(batch).float().cpu().numpy())
labels = np.concatenate(labels)
embeddings = np.concatenate(embeddings)
emb_causal = embeddings[labels == 1].mean(axis=0)
emb_noncausal = embeddings[labels == 0].mean(axis=0)
emb_causal.shape, emb_noncausal.shape

update path: hdf5_file = hg38_6k_1kb.hdf5
update path: meta_file = config/hg38_6k_meta.json
update path: fasta_file = other/hg38.fa


100%|██████████| 768/768 [00:34<00:00, 22.00it/s]


((1064, 768), (1064, 768))

In [17]:
df_shift = pd.DataFrame(
    {
        "regulator":model_emb.list_regulator,
        "shift": 1- sklearn.metrics.pairwise.cosine_similarity(emb_causal, emb_noncausal).diagonal(),
    }
).sort_values("shift", ascending=False, ignore_index=True)

df_shift.head(20)

Unnamed: 0,regulator,shift
0,h3k4me3,0.213028
1,dnase,0.21297
2,faire,0.207958
3,h3k4me2,0.199916
4,maz,0.197446
5,kdm2b,0.19714
6,chd8,0.196293
7,zfx,0.193517
8,gabpa,0.191083
9,nrf1,0.190021
