In [1]:
import numpy as np
import pandas as pd
import warnings
import pickle

from collections import defaultdict

import os

import torch 
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, BertForMaskedLM

#from DNABERT2.bert_layers import BertModel as DNABERT2
from DNABERT.src.transformers.tokenization_dna import DNATokenizer

import helpers.misc as misc

%load_ext autoreload
%autoreload 2

In [2]:
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [3]:
model_dirs = {
'dnabert':data_dir + 'models/whole_genome/dnabert/6-new-12w-0/',
'dnabert2':data_dir + 'models/whole_genome/dnabert2/DNABERT-2-117M/',
'ntrans-v2-100m':data_dir + 'models/whole_genome/nucleotide-transformer-v2-100m-multi-species',
'ntrans-v2-250m':data_dir + 'models/whole_genome/nucleotide-transformer-v2-250m-multi-species',
'ntrans-v2-500m':data_dir + 'models/whole_genome/nucleotide-transformer-v2-500m-multi-species',
'dnabert-3utr':data_dir + 'models/zoonomia-3utr/dnabert-3utr/checkpoints/epoch_30/',
'dnabert2-3utr':data_dir + 'models/zoonomia-3utr/dnabert2-3utr/checkpoints/epoch_18/',
'ntrans-v2-250m-3utr':data_dir + 'models/zoonomia-3utr/ntrans-v2-250m-3utr/checkpoints/epoch_23/',
'dnabert2-zoo':data_dir + 'models/zoonomia/dnabert2-z/checkpoints/chkpt_336/',
'dnabert-3utr-2e':data_dir + 'models/zoonomia-3utr/dnabert-3utr-2e/checkpoints/chkpt_40/',
'dnabert2-3utr-2e':data_dir + 'models/zoonomia-3utr/dnabert2-3utr-2e/checkpoints/chkpt_275/',
'ntrans-v2-100m-3utr-2e':data_dir + 'models/zoonomia-3utr/ntrans-v2-100m-3utr-2e/checkpoints/chkpt_633/',
'ntrans-v2-250m-3utr-2e':data_dir + 'models/zoonomia-3utr/ntrans-v2-250m-3utr-2e/checkpoints/chkpt_56/',
}

In [39]:
input_params = misc.dotdict({})

input_params.fasta = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/mpra/siegel_2022/fasta/variants_dna_fwd.fa'

input_params.model = 'dnabert'

input_params.output_dir = './test/'

input_params.batch_size = 30

input_params.include_txt = None 

input_params.N_folds = None

input_params.fold=None

In [40]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    cuda_device_name = torch.cuda.get_device_name(0)
    print(f'\nCUDA device: {cuda_device_name}\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
    #raise Exception('CUDA is not found')


CUDA device: Tesla V100S-PCIE-32GB



In [41]:
max_length = {'dnabert':510,
              'dnabert2':512,'dnabert2-zoo':1024,
              'ntrans-v2-100m':1024, 'ntrans-v2-250m':1024,'ntrans-v2-500m':1024,
              'dnabert-3utr':510,'dnabert2-3utr':1024,'ntrans-v2-250m-3utr':1024,
              'dnabert-3utr-2e':510, 'dnabert2-3utr-2e':1024, 'ntrans-v2-250m-3utr-2e':1024, 'ntrans-v2-100m-3utr-2e':1024, 
              }

In [42]:
def load_model(model_name):

    print(f'Loading model {model_name} from {model_dirs[model_name]}')

    if  'dnabert' in model_name and not 'dnabert2' in model_name:
        
        tokenizer = DNATokenizer(vocab_file='./DNABERT/src/transformers/dnabert-config/bert-config-6/vocab.txt',max_len=510)
        model = BertForMaskedLM.from_pretrained(model_dirs[model_name]).to(device);

    elif 'dnabert2' in model_name:

        tokenizer = AutoTokenizer.from_pretrained(model_dirs[model_name],trust_remote_code=True)
        embeddings_model = AutoModel.from_pretrained(model_dirs[model_name],trust_remote_code=True).to(device);
        prediction_model = BertForMaskedLM.from_pretrained(model_dirs[model_name]).to(device);
        model = (embeddings_model, prediction_model)
    
    elif 'ntrans' in model_name:

        # Import the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_dirs[model_name],trust_remote_code=True)
        model = AutoModelForMaskedLM.from_pretrained(model_dirs[model_name],trust_remote_code=True).to(device);

    return tokenizer, model

In [43]:
class SeqDataset(Dataset):
    
    def __init__(self, fasta_file, fold=None, N_folds=None, include_txt=None):
        
        seqs = defaultdict(str)
            
        with open(fasta_file, 'r') as f:
            for line in f:
                if line.startswith('>'):
                    transcript_id = line[1:].rstrip()
                else:
                    seqs[transcript_id] += line.rstrip().upper()
                    
        #seqs = {k:v[:MAX_SEQ_LENGTH] for k,v in seqs.items()}
        #seqs = {k:''.join(np.random.choice(list('ACGT'),size=MAX_LENGTH)) for k,v in seqs.items()}
        seqs = list(seqs.items())

        if include_txt!=None:
            print(f'Including sequences from {include_txt}')
            processed_seqs = pd.read_csv(include_txt,names=['seq_name']).seq_name.values
            seqs = [(seq_name,seq) for seq_name,seq in seqs if seq_name in processed_seqs]
        if N_folds!=None:
            print(f'Fold {fold}')
            folds = np.tile(np.arange(N_folds),len(seqs)//N_folds+1)[:len(seqs)]
            seqs = [x for idx,x in enumerate(seqs) if folds[idx]==fold]
            
        self.seqs = seqs
        self.max_length = max([len(seq[1]) for seq in self.seqs])
        
    def __len__(self):
        
        return len(self.seqs)
    
    def __getitem__(self, idx):
        
        return self.seqs[idx]

In [44]:
def kmers_stride1(seq, k=6):
    # splits a sequence into overlapping k-mers
    return [seq[i:i + k] for i in range(0, len(seq)-k+1)] 

In [69]:
def get_batch_embeddings(model_name, sequences):

    if input_params.max_tokenized_length is None:
        max_tokenized_length = max_length[model_name]
    else:
        max_tokenized_length = input_params.max_tokenized_length

    if 'dnabert' in model_name and not 'dnabert2' in model_name:

        mean_sequence_embeddings = []
        losses = []

        #special_token_ids = [tokenizer.pad_token_id, tokenizer.mask_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id, tokenizer.unk_token_id]
        
        for seq in sequences:

            if len(seq)<6:
                emb_seq = np.zeros((1,768))
                emb_seq[:] = np.nan
                losses.append(np.nan)
                mean_sequence_embeddings.append(emb_seq)
                continue

            if len(seq)>max_tokenized_length:
                warnings.warn('Cutting out the central part of the sequence to center around DNABERT FOV')
                seq = misc.center_seq(seq,max_tokenized_length)

            seq_kmer = kmers_stride1(seq)
        
            inputs = tokenizer.encode_plus(seq_kmer,
                                                truncation = True,
                                                return_tensors = 'pt',  
                                                add_special_tokens=True, 
                                                max_length=512).to(device)

            torch_outs = model(inputs["input_ids"], 
                                labels = inputs["input_ids"],
                                output_hidden_states=True)
                
            emb_seq = torch_outs.hidden_states[-1].mean(dim=1).detach().cpu().numpy()
            
            mean_sequence_embeddings.append(emb_seq)

            attention_mask = inputs['attention_mask'].squeeze()
                
            loss = F.cross_entropy(torch_outs.logits.reshape((-1,torch_outs.logits.shape[-1])), 
                      inputs["input_ids"].reshape(-1), reduction='none')
        
            loss = (loss*attention_mask).sum()/attention_mask.sum()
            
            losses.append(loss.item())
            
            #probas = F.softmax(torch_outs['logits'],dim=2).detach().cpu().numpy()
            #token_ids = model_input.cpu().numpy()
            #gt_probas = np.take_along_axis(probas, token_ids[...,None], axis=2)
            #gt_probas = gt_probas[~np.isin(token_ids,special_token_ids)]
            #log_probas_seq = np.log(gt_probas).squeeze()
            #log_probas.append(log_probas_seq)

        return (np.vstack(mean_sequence_embeddings),np.array(losses))

    elif 'dnabert2' in model_name:
            
        inputs = tokenizer(sequences, 
                           truncation=True, 
                           return_tensors = 'pt', 
                           padding="max_length", 
                           max_length = max_tokenized_length).to(device)
        
        assert 'A100' in cuda_device_name, 'A100 GPU is required to generate embeddings for the DNABERT2 model'

        attention_mask = inputs["attention_mask"]

        hidden_states = model[0](inputs["input_ids"],
                                attention_mask=attention_mask)[0] # [1, sequence_length, 768]

        torch_outs = model[1](inputs["input_ids"], attention_mask=attention_mask)

        losses = F.cross_entropy(torch_outs.logits.reshape((-1,torch_outs.logits.shape[-1])), 
                      inputs["input_ids"].reshape(-1), reduction='none').reshape((len(sequences),-1))
        
        losses = (losses*attention_mask).sum(1)/attention_mask.sum(1)

        attention_mask = torch.unsqueeze(attention_mask, dim=-1)
        
        # Compute mean embeddings per sequence
        mean_sequence_embeddings = torch.sum(attention_mask*hidden_states, axis=-2)/torch.sum(attention_mask, axis=1)
        
        #mean_sequence_embeddings = torch.mean(hidden_states, dim=1)# embedding with mean pooling

        losses = losses.detach().cpu().numpy()
        mean_sequence_embeddings = mean_sequence_embeddings.detach().cpu().numpy()

        return (mean_sequence_embeddings, losses)

    elif 'ntrans' in model_name:

        inputs = tokenizer.batch_encode_plus(sequences, 
                                                      truncation = True,
                                                      return_tensors="pt", 
                                                      padding="max_length", 
                                                      max_length = max_tokenized_length).to(device)
                    
        torch_outs = model(
            inputs["input_ids"],
            attention_mask=inputs['attention_mask'],
            encoder_attention_mask=inputs['attention_mask'],
            output_hidden_states=True)
        
        embeddings = torch_outs['hidden_states'][-1]
        
        attention_mask = inputs['attention_mask']
                
        losses = F.cross_entropy(torch_outs.logits.reshape((-1,torch_outs.logits.shape[-1])), 
                      inputs["input_ids"].reshape(-1), reduction='none').reshape((len(sequences),-1))
        
        losses = (losses*attention_mask).sum(1)/attention_mask.sum(1)

        # Add embed dimension axis
        attention_mask = torch.unsqueeze(attention_mask, dim=-1)
        
        # Compute mean embeddings per sequence
        mean_sequence_embeddings = torch.sum(attention_mask*embeddings, axis=-2)/torch.sum(attention_mask, axis=1)
        #print(f"Mean sequence embeddings: {mean_sequence_embeddings}")

        #probas = F.softmax(torch_outs['logits'],dim=2).cpu().numpy()
        #inputs = inputs.cpu().numpy()
        #gt_probas = np.take_along_axis(probas, inputs[...,None], axis=2).squeeze()
        #log_probas = np.log(gt_probas)
        
        mean_sequence_embeddings = mean_sequence_embeddings.detach().cpu().numpy()
        losses = losses.detach().cpu().numpy()

    return (mean_sequence_embeddings, losses)

In [46]:
tokenizer, model = load_model(input_params.model)

Loading model dnabert from /lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/whole_genome/dnabert/6-new-12w-0/


In [47]:
dataset = SeqDataset(input_params.fasta, fold=input_params.fold, N_folds=input_params.N_folds, include_txt=input_params.include_txt)

dataloader = DataLoader(dataset = dataset,
                        batch_size = input_params.batch_size,
                        num_workers = 2, collate_fn = None, shuffle = False)

In [48]:
#for model_name in ('dnabert','dnabert2', 'ntrans-v2-250m', 'dnabert-3utr','dnabert2-3utr', 'ntrans-v2-250m-3utr'):
#    print(model_name)
#    tokenizer, model = load_model(model_name)
#    seq_names,sequences = next(iter(dataloader))
#    embeddings, losses = get_batch_embeddings(model_name,sequences)
#    print(model, embeddings.shape, losses.shape)

In [49]:
#seq = ''.join(np.random.choice(list('ACGT'),size=100))

In [70]:
seq_names,sequences = next(iter(dataloader))
with torch.no_grad():
    embeddings, losses = get_batch_embeddings(input_params.model,sequences)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([153]) torch.Size([153])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
torch.Size([157]) torch.Size([157])
to

In [14]:
all_emb, all_losses = [], []

for seq_idx, (seq_names,sequences) in enumerate(dataloader):

    #if os.path.isfile(input_params.output_dir + f'/{seq_names[0]}.pickle'):
    #    continue
        
    print(f'generating embeddings for batch {seq_idx}/{len(dataloader)}')
    
    with torch.no_grad():
        embeddings, losses = get_batch_embeddings(input_params.model,sequences)

    all_emb.append(embeddings)
    all_losses.append(losses)

    #with open(input_params.output_dir + f'/{seq_names[0]}.pickle', 'wb') as f:
    #    pickle.dump((seq_names,emb,logprobs),f)

if input_params.fold!=None:
    output_name = input_params.output_dir + f'/predictions_{input_params.fold}.pickle'
else:
    output_name = input_params.output_dir + '/predictions.pickle'

os.makedirs(input_params.output_dir, exist_ok=True)

seq_names, seqs = zip(*dataset.seqs)

with open(output_name, 'wb') as f:
    pickle.dump({'seq_names':seq_names, 'seqs':seqs, 'embeddings':np.vstack(all_emb), 
                 'losses':np.hstack(all_losses), 'fasta':input_params.fasta},f)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
generating embeddings for batch 0/605
generating embeddings for batch 1/605
generating embeddings for batch 2/605
generating embeddings for batch 3/605
generating embeddings for batch 4/605
generating embeddings for batch 5/605
generating embeddings for batch 6/605
generating embeddings for batch 7/605
generating embeddings for batch 8/605
generating embeddings for batch 9/605
generating embeddings for batch 10/605
generating embe


KeyboardInterrupt



In [114]:
with open('/lustre/groups/epigenereg01/workspace/projects/vale/mlm/mpra/griesemer_2021//embeddings/dnabert2-zoo/predictions.pickle','rb') as f:
    data=pickle.load(f)
print(data['seq_names'][:5])
data['embeddings']

('10_75883134_CTT_alt', '10_75883134_CTT_ref', '11_100864071_AT_alt', '11_100864071_AT_ref', '11_100908083_TA_alt')


array([[ 0.1843825 , -0.34444427,  0.17333485, ...,  0.05326335,
        -0.7195219 ,  0.20871626],
       [ 0.22563484, -0.3919892 ,  0.0870655 , ...,  0.07092983,
        -0.82888156,  0.0926896 ],
       [-0.20281596, -0.09872153, -0.11541392, ...,  0.4846839 ,
        -0.79226977,  0.25623947],
       ...,
       [ 0.15662529,  0.11859999,  0.30323905, ...,  0.42049447,
        -0.49458238, -0.12474364],
       [-0.13251446, -0.3249906 ,  0.11943641, ...,  0.1247849 ,
        -0.21506275,  0.41802666],
       [-0.20364851, -0.41124943,  0.18108313, ...,  0.15091987,
        -0.37516996,  0.36905167]], dtype=float32)

In [115]:
with open('/lustre/groups/epigenereg01/workspace/projects/vale/mlm/mpra/griesemer_2021//embeddings/dnabert2/predictions.pickle','rb') as f:
    data=pickle.load(f)
print(data['seq_names'][:5])
data['embeddings']

('10_75883134_CTT_alt', '10_75883134_CTT_ref', '11_100864071_AT_alt', '11_100864071_AT_ref', '11_100908083_TA_alt')


array([[-0.09332177,  0.07897416,  0.06044902, ...,  0.0529453 ,
         0.05022964,  0.1221395 ],
       [-0.10081972,  0.06575786,  0.05887064, ...,  0.10097329,
         0.04593064,  0.12111768],
       [-0.01290491,  0.05902864,  0.14830868, ...,  0.00596794,
         0.06165176,  0.17206357],
       ...,
       [-0.05180409,  0.06522423,  0.12311535, ...,  0.04234905,
         0.08820695,  0.03821801],
       [-0.04836041,  0.07230833,  0.08017981, ...,  0.00694123,
         0.04148955,  0.14174563],
       [-0.05082843,  0.0868283 ,  0.09633591, ...,  0.00536785,
         0.02304393,  0.1380948 ]], dtype=float32)

In [21]:
#with open(data_dir + '/variants/embeddings/dnabert2/predictions.pickle','rb') as fin:
#    data = pickle.load(fin)