In [38]:
import torch
import sys
import pandas as pd
from tqdm import tqdm
import pickle

sys.path.append('DNABERT/')

from src.transformers import DNATokenizer 
from transformers import BertModel, BertConfig

In [None]:
# Local checkpoint directory holding the DNABERT weights to embed with.
dir_to_pretrained_model = "/s/project/mll/sergey/effect_prediction/MLM/dnabert/default/6-new-12w-0/"

# Architecture config taken from the DNABERT GitHub repo (bert-config-6).
# NOTE(review): this tracks the `master` branch at fetch time — consider pinning a commit.
config = BertConfig.from_pretrained(
    'https://raw.githubusercontent.com/jerryji1993/DNABERT/master/src/transformers/dnabert-config/bert-config-6/config.json'
)

# DNABERT tokenizer; 'dna6' presumably selects the 6-mer vocabulary matching the config above.
tokenizer = DNATokenizer.from_pretrained('dna6')

# Weights from the local checkpoint, architecture from `config`.
model = BertModel.from_pretrained(
    dir_to_pretrained_model,
    config=config,
)

In [13]:
def kmers_stride1(seq, k=6):
    """Return every length-``k`` window of ``seq``, sliding one position at a time.

    The windows are returned left to right, ``len(seq) - k + 1`` of them.
    The result is an empty list when ``seq`` is shorter than ``k``.
    """
    windows = []
    for start in range(len(seq) - k + 1):
        windows.append(seq[start:start + k])
    return windows

In [56]:
def get_embedding(seq):
    """Embed one DNA sequence with the module-level DNABERT ``model``.

    ``seq`` is split into overlapping 6-mers with ``kmers_stride1``. It is then
    encoded with the module-level ``tokenizer`` (special tokens added,
    ``max_length=512``) and passed through ``model`` as a batch of size one.

    Args:
        seq: DNA sequence string. NOTE(review): with ``max_length=512`` the
            tokenizer presumably truncates sequences that yield more than 510
            k-mers. Confirm that input lengths stay below that.

    Returns:
        numpy.ndarray: ``output[1][0]``, the model's second output for the single
        batch element. This is presumably BERT's pooled [CLS] vector; confirm
        against the installed transformers version.
    """
    seq_kmer = kmers_stride1(seq)

    # A batch of one never needs padding. Padding to 512 only made every forward
    # pass full length, and since no attention mask was passed, [PAD] positions
    # were attended to.
    encoded = tokenizer.encode_plus(seq_kmer, add_special_tokens=True, max_length=512)
    input_ids = encoded["input_ids"]
    # Pass the mask explicitly so that any padding the tokenizer adds is ignored.
    # If no mask is returned, fall back to attending to every token.
    attention_mask = encoded.get("attention_mask", [1] * len(input_ids))

    # Build (1, seq_len) tensors (a fake batch of one) on the model's device.
    device = next(model.parameters()).device
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    attention_mask = torch.tensor([attention_mask], dtype=torch.long, device=device)

    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)

    return output[1][0].cpu().numpy()

In [57]:
# Input/output directory. The trailing slash matters because file names are appended by plain string concatenation.
data_dir = "/s/project/mll/sergey/effect_prediction/MLM/clinvar/dnabert/default/"

In [58]:
# One row per variant; the embedding loop reads the var_id, refseq and altseq columns.
seqs_df = pd.read_csv(f"{data_dir}seqs.csv")

In [59]:
# Map var_id -> (reference embedding, alternative embedding).
# NOTE(review): a repeated var_id silently overwrites the earlier entry.
all_embd = {}

for _, var in tqdm(seqs_df.iterrows(), total=len(seqs_df)):
    # Embed ref first, then alt, and store them as a (ref, alt) pair.
    all_embd[var.var_id] = tuple(
        get_embedding(sequence) for sequence in (var.refseq, var.altseq)
    )

100%|██████████| 21270/21270 [2:58:10<00:00,  1.99it/s]  


In [60]:
# Persist the {var_id: (ref, alt)} embedding dict with the default pickle protocol.
with open(f"{data_dir}embeddings.pickle", "wb") as out_file:
    pickle.Pickler(out_file).dump(all_embd)

In [62]:
print("Done")

Done
