In [16]:
from transformers.models.bert.configuration_bert import BertConfig
from transformers import AutoTokenizer, AutoModel
import torch

# Hugging Face Hub id shared by the tokenizer, the config and the model weights.
DNABERT2_CHECKPOINT = "zhihan1996/DNABERT-2-117M"

# trust_remote_code=True permits transformers to execute Python code shipped with
# the checkpoint repository, so only use it with a source you trust.
tokenizer = AutoTokenizer.from_pretrained(DNABERT2_CHECKPOINT, trust_remote_code=True)

# The config is loaded explicitly as a BertConfig and handed to AutoModel.
config = BertConfig.from_pretrained(DNABERT2_CHECKPOINT)
model = AutoModel.from_pretrained(
    DNABERT2_CHECKPOINT,
    trust_remote_code=True,
    config=config,
)

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
# Smoke test: embed one short DNA string and confirm that both pooled vectors
# come out with 768 entries.
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors='pt')["input_ids"]

# Inference only: disable autograd so no computation graph is built or kept alive
# by the hidden states.
with torch.no_grad():
    hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

# embedding with mean pooling (average over the token axis)
embedding_mean = hidden_states[0].mean(dim=0)
print(embedding_mean.shape)  # expect to be 768

# embedding with max pooling (element-wise maximum over the token axis)
embedding_max = hidden_states[0].max(dim=0).values
print(embedding_max.shape)  # expect to be 768


torch.Size([768])
torch.Size([768])


In [18]:
import pandas as pd

# The file carries a .csv extension but is read as tab-separated.
data_path = "./data/data_with_human_TE_cellline_all_plain.csv"
df = pd.read_csv(data_path, sep="\t")

# Preview the first rows.
df.head()

Unnamed: 0,SYMBOL,transcript_id,gene_id,tx_size,utr5_size,cds_size,utr3_size,tx_sequence,bio_source_108T,bio_source_12T,...,struct_max_stem_len_UTR5,struct_max_loop_len_UTR5,struct_min_dG_CDS,struct_n_hairpins_CDS,struct_n_bifurc_CDS,struct_n_bulges_CDS,struct_start_stem_CDS,struct_max_stem_len_CDS,struct_max_loop_len_CDS,fold
0,SAMD11,ENST00000342066.8,ENSG00000187634.12,2557,90,2046,421,GCAGAGCCCAGCAGATCCCTGCGGCGTTCGCGAGGGTGGGACGGGA...,-3.644472,0.98672,...,10.0,5.0,-13.9,2.0,1.0,0.0,9.0,4.0,4.0,4
1,NOC2L,ENST00000327044.7,ENSG00000188976.11,2757,16,2250,491,GCTTCGGGTTGGTGTCATGGCAGCTGCGGGGAGCCGCAAGAGGCGC...,1.06019,0.701399,...,10.0,6.0,-24.5,1.0,0.0,0.0,21.0,10.0,6.0,8
2,KLHL17,ENST00000338591.8,ENSG00000187961.14,2567,110,1929,528,GGGAGTGAGCGACACAGAGCGGGCCGCCACCGCCGAGCAGCCCTCC...,-1.198005,-1.178952,...,10.0,4.0,-23.0,1.0,1.0,1.0,3.0,9.0,5.0,9
3,HES4,ENST00000304952.11,ENSG00000188290.11,885,124,666,95,GCGGGCCTGGAGCCGGGATCCGCCCTAGGGGCTCGGATCGCCGCGC...,-1.1074,0.158079,...,16.0,3.0,-24.8,1.0,0.0,0.0,3.0,11.0,3.0,7
4,ISG15,ENST00000649529.1,ENSG00000187608.10,637,77,498,62,GGCGGCTGAGAGGCAGCGAACTCATCTTTGCCAGTACAGGAGCTTG...,0.631561,2.013887,...,8.0,5.0,-28.6,1.0,0.0,2.0,2.0,13.0,5.0,2


In [19]:
# Quick overview of the table: its dimensions, how many rows contain a missing
# value, and how many cell-line columns it has.
n_rows, n_columns = df.shape
print("Number of columns: ", n_columns)
print("Number of rows: ", n_rows)

# Rows with at least one missing value in any column.
na_rows = df.loc[df.isna().any(axis=1)]
print("Number of rows that have NA: ", len(na_rows))

# Number of cell lines == number of columns whose name contains "bio_source".
bio_source_cols = [name for name in df.columns if 'bio_source' in name]
print(f"Number of unique human cell lines: {len(bio_source_cols)}")

Number of columns:  102
Number of rows:  11153
Number of rows that have NA:  354
Number of unique human cell lines: 78


In [20]:
# Discard every row that has a missing value in any column, then verify that
# none are left.
df = df.dropna()
na_rows = df.loc[df.isna().any(axis=1)]
print("Number of rows that have NA: ", len(na_rows))
print("Number of rows: ", len(df))

# Keep only transcripts whose sequence is 500 to 1500 characters long (both
# bounds inclusive) and renumber the surviving rows from 0.
tx_lengths = df['tx_sequence'].str.len()
df_filtered = df.loc[tx_lengths.between(500, 1500)].reset_index(drop=True)

# Preview the NA-free table (not the length-filtered one).
df.head()

Number of rows that have NA:  0
Number of rows:  10799


Unnamed: 0,SYMBOL,transcript_id,gene_id,tx_size,utr5_size,cds_size,utr3_size,tx_sequence,bio_source_108T,bio_source_12T,...,struct_max_stem_len_UTR5,struct_max_loop_len_UTR5,struct_min_dG_CDS,struct_n_hairpins_CDS,struct_n_bifurc_CDS,struct_n_bulges_CDS,struct_start_stem_CDS,struct_max_stem_len_CDS,struct_max_loop_len_CDS,fold
0,SAMD11,ENST00000342066.8,ENSG00000187634.12,2557,90,2046,421,GCAGAGCCCAGCAGATCCCTGCGGCGTTCGCGAGGGTGGGACGGGA...,-3.644472,0.98672,...,10.0,5.0,-13.9,2.0,1.0,0.0,9.0,4.0,4.0,4
1,NOC2L,ENST00000327044.7,ENSG00000188976.11,2757,16,2250,491,GCTTCGGGTTGGTGTCATGGCAGCTGCGGGGAGCCGCAAGAGGCGC...,1.06019,0.701399,...,10.0,6.0,-24.5,1.0,0.0,0.0,21.0,10.0,6.0,8
2,KLHL17,ENST00000338591.8,ENSG00000187961.14,2567,110,1929,528,GGGAGTGAGCGACACAGAGCGGGCCGCCACCGCCGAGCAGCCCTCC...,-1.198005,-1.178952,...,10.0,4.0,-23.0,1.0,1.0,1.0,3.0,9.0,5.0,9
3,HES4,ENST00000304952.11,ENSG00000188290.11,885,124,666,95,GCGGGCCTGGAGCCGGGATCCGCCCTAGGGGCTCGGATCGCCGCGC...,-1.1074,0.158079,...,16.0,3.0,-24.8,1.0,0.0,0.0,3.0,11.0,3.0,7
4,ISG15,ENST00000649529.1,ENSG00000187608.10,637,77,498,62,GGCGGCTGAGAGGCAGCGAACTCATCTTTGCCAGTACAGGAGCTTG...,0.631561,2.013887,...,8.0,5.0,-28.6,1.0,0.0,2.0,2.0,13.0,5.0,2


In [21]:
from tqdm import tqdm
# Embed every length-filtered transcript and keep, per transcript, one mean-pooled
# and one max-pooled vector over the token axis of the model's first output.
embedding_means = []
embedding_maxs = []
transcript_ids = []

# Send inputs to whichever device the model's weights live on, so the loop also
# works after the model has been moved to a GPU (results are brought back with
# .cpu() below).
device = next(model.parameters()).device

# Iterate over just the two columns that are needed; iterrows() would build a
# full Series (one entry per column) for every row.
id_sequence_pairs = zip(df_filtered['transcript_id'], df_filtered['tx_sequence'])

# The tqdm bar is the only progress report: a print per transcript would add one
# output line per row and break up the bar.
for transcript_id, dna_sequence in tqdm(id_sequence_pairs, total=len(df_filtered), desc="Generating embeddings"):
    with torch.no_grad():  # No gradients needed
        inputs = tokenizer(dna_sequence, return_tensors='pt')["input_ids"].to(device)
        hidden_states = model(inputs)[0]  # [1, sequence_length, 768]

        # Mean pooling
        embedding_mean = hidden_states[0].mean(dim=0)  # [768]

        # Max pooling
        embedding_max = hidden_states[0].max(dim=0).values  # [768]

    # Save to lists (as NumPy arrays on the host)
    transcript_ids.append(transcript_id)
    embedding_means.append(embedding_mean.cpu().numpy())
    embedding_maxs.append(embedding_max.cpu().numpy())


# Create a new DataFrame: one row per transcript, each embedding stored as an
# array in an object column.
embeddings_df = pd.DataFrame({
    'transcript_id': transcript_ids,
    'embedding_mean': embedding_means,
    'embedding_max': embedding_maxs
})

# Save to a pickle file, which keeps the array-valued columns intact.
embeddings_df.to_pickle('transcript_embeddings.pkl')

Generating embeddings:   0%|          | 0/1267 [00:26<?, ?it/s]


KeyboardInterrupt: 