In [1]:
from transformers import Trainer

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, BertForMaskedLM

import torch 
from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding

from datasets import Dataset

import math
import itertools
from collections.abc import Mapping
import numpy as np
import pandas as pd
from collections import defaultdict
import tqdm

In [2]:
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/whole_genome/dnabert/6-new-12w-0/'
#model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert-3utr/checkpoints/epoch_30/'

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForMaskedLM.from_pretrained(model_dir)

In [4]:
#model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert-3utr/checkpoints/epoch_30/'
#model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/dnabert/6-new-12w-0/'


#tokenizer = AutoTokenizer.from_pretrained(model_dir)
#tokenizer.encode('GGGGGG')

In [5]:
if not '3utr' in model_dir:
    #dna model: model trained on sequences that weren't reverse complemented for genes on negative strand
    #all default DNABERT models
    dna_model = True
else:
    dna_model = False

dna_model

True

In [6]:
MASKING=1

# Utility Functions

## Tokenization

In [7]:
nuc_dict = {"A":0,"C":1,"G":2,"T":3}

def chunkstring(string, length):
    # chunks a string into segments of length
    return (string[0+i:length+i] for i in range(0, len(string), length))

def kmers(seq, k=6):
    # splits a sequence into non-overlappnig k-mers
    return [seq[i:i + k] for i in range(0, len(seq), k) if i + k <= len(seq)]

def kmers_stride1(seq, k=6):
    # splits a sequence into overlapping k-mers
    return [seq[i:i + k] for i in range(0, len(seq)-k+1)]   

def tok_func(x): return tokenizer(" ".join(kmers_stride1(x["seq_chunked"])))

def one_hot_encode(gts, dim=5):
    result = []
    for nt in gts:
        vec = np.zeros(dim)
        vec[nuc_dict[nt]] = 1
        result.append(vec)
    return np.stack(result, axis=0)

def class_label_gts(gts):
    return np.array([nuc_dict[x] for x in gts])

In [8]:
#from transformers import  DataCollatorForLanguageModeling
#data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability = 0.15)
torch.manual_seed(0)

def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    length_of_first = examples[0].size(0)

    # Check if padding is necessary.

    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result

class DataCollatorForLanguageModelingSpan():
    def __init__(self, tokenizer, mlm, mlm_probability, span_length):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.span_length =span_length
        self.mlm_probability= mlm_probability
        self.pad_to_multiple_of = span_length

    def __call__(self, examples):
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def torch_mask_tokens(self, inputs, special_tokens_mask):
        import torch

        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability*0.2)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool().numpy()
        masked_indices = np.apply_along_axis(lambda m : np.convolve(m, [1] * self.span_length, mode = 'same' ),axis = 1, arr = masked_indices).astype(bool) 
        masked_indices = torch.from_numpy(masked_indices)
        m_save = masked_indices.clone()
        
        probability_matrix = torch.full(labels.shape, self.mlm_probability*0.8) 
        probability_matrix.masked_fill_(masked_indices, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool().numpy()
        masked_indices = np.apply_along_axis(lambda m : np.convolve(m, [1] * self.span_length, mode = 'same' ),axis = 1, arr = masked_indices).astype(bool) 
        masked_indices = torch.from_numpy(masked_indices)
        m_final = masked_indices + m_save 
        labels[~m_final] = -100  # We only compute loss on masked tokens
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        #indices_replaced = torch.bernoulli(torch.full(labels.shape, 1.0)).bool()
        #print (indices_replaced)
        inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        #print (masked_indices)

        # 10% of the time, we replace masked input tokens with random word
        #indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        #random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        #inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

In [14]:
kmers

<function __main__.kmers(seq, k=6)>

In [96]:
seq = np.random.choice(list('ACGT'), size=510)
seq = ''.join(seq)

kmers = kmers_stride1(seq)
len(kmers)

In [97]:
kmers = kmers_stride1(seq)
len(kmers)

505

In [98]:
central_pos = len(seq)//2
central_token_idx = central_pos-5 #first token, where the central nucleotide appears
window = 5
token_idx_left,token_idx_right = central_token_idx,central_token_idx+window+6-1
pos_left,pos_right = central_pos, central_pos+window

In [99]:
kmers[token_idx_left][-1] == seq[pos_left]

True

In [100]:
kmers[token_idx_right-1][0] == seq[pos_right-1]

True

## Prediction

In [39]:
def predict_on_batch(tokenized_data, dataset, seq_idx):
    model_input_unaltered = tokenized_data['input_ids'].clone()
    label = dataset.iloc[seq_idx]['seq_chunked']
    label_len = len(label)
    if label_len < 6:
        return torch.zeros(label_len,label_len,5), None
    else:
        if MASKING:
            diag_matrix = torch.eye(tokenized_data['input_ids'].shape[1]).numpy()
            masked_indices = np.apply_along_axis(lambda m : np.convolve(m, [1] * 6, mode = 'same' ),axis = 1, arr = diag_matrix).astype(bool)
            masked_indices = torch.from_numpy(masked_indices)
            masked_indices = masked_indices[3:label_len-5-2]
            res = tokenized_data['input_ids'].expand(masked_indices.shape[0],-1).clone()
            targets_masked = res.clone().to(device)
            res[masked_indices] = 4
            targets_masked[res!=4] = -100
            #print (res[0], res.shape)
        res = res.to(device)
        with torch.no_grad():
            model_outs = model(res,labels=targets_masked)
            fin_calculation = torch.softmax(model_outs['logits'], dim=2).detach().cpu()   
        return fin_calculation, model_outs['loss']

## Translating predictions

In [40]:
def extract_prbs_from_pred(prediction, pred_pos, token_pos, label_pos, label):   
    # pred_pos = "kmer" position in tokenized sequence (incl. special tokens)
    # token_pos = position of nucleotide in kmer
    # label_pos = position of actual nucleotide in sequence
    model_pred = prediction
    prbs = [torch.sum(model_pred[pred_pos,tokendict_list[token_pos][nuc]]) for nuc in ["A","C","G","T"]]
    gt = label[label_pos] # 6-CLS, zerobased
    res = torch.tensor(prbs+[0.0])
    return res, gt

# Prepare inputs

## Prepare dataframe

In [41]:
def reverse_complement(seq):
    '''
    Take sequence reverse complement
    '''
    compl_dict = {'A':'T', 'C':'G', 'G':'C', 'T':'A'}
    compl_seq = ''.join([compl_dict.get(x,x) for x in seq])
    rev_seq = compl_seq[::-1]
    return rev_seq

In [42]:
N_folds = 10
fold = 2

In [43]:
fasta = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/fasta/Homo_sapiens_rna.fa'

dataset = defaultdict(str)

with open(fasta, 'r') as f:
    for line in f:
        if line.startswith('>'):
            seq_name = line[1:].rstrip()
        else:
            dataset[seq_name] += line.rstrip().upper()
            
dataset = pd.DataFrame(list(dataset.items()), columns=['seq_name','seq'])

folds = np.arange(N_folds).repeat(len(dataset)//N_folds+1)[:len(dataset)] 

dataset = dataset.loc[folds==fold]

strand_info = pd.read_csv(data_dir + 'UTR_coords/GRCh38_3_prime_UTR_clean-sorted.bed', sep='\t', header = None, names=['seq_name','strand'], usecols=[3,5]).set_index('seq_name').squeeze()

dataset['original_seq'] = dataset['seq'] 

dataset['seq'] = dataset.apply(lambda x: reverse_complement(x.seq) if strand_info.loc[x.seq_name]=='-' and dna_model else x.seq, axis=1) #undo reverse complement

In [44]:
dataset['seq'] = dataset.seq.apply(lambda x:''.join(np.random.choice(list(x),size=len(x),replace=False)))

In [45]:
dataset['seq_chunked'] = dataset['seq'].apply(lambda x : list(chunkstring(x, 510))) #chunk string in segments of 300

In [46]:
#output_dir = '/s/project/mll/sergey/effect_prediction/MLM/motif_predictions/split_75_25/dnabert/default/'
#
#dataset_len = 500
#
#for dataset_start in range(0,len(dataset),dataset_len):
#    df = dataset.iloc[dataset_start:dataset_start+dataset_len]
#    df[['seq_name','seq']].to_csv(output_dir + f'/seq_{dataset_start}.csv', index=None)

In [47]:
dataset = dataset.explode('seq_chunked')

In [48]:
ds = Dataset.from_pandas(dataset[['seq_chunked']])

tok_ds = ds.map(tok_func, batched=False,  num_proc=2)

rem_tok_ds = tok_ds.remove_columns('seq_chunked')

data_collator = DataCollatorForLanguageModelingSpan(tokenizer, mlm=False, mlm_probability = 0.025, span_length =6)
data_loader = torch.utils.data.DataLoader(rem_tok_ds, batch_size=1, collate_fn=data_collator, shuffle = False)

Map (num_proc=2):   0%|          | 0/8359 [00:00<?, ? examples/s]

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


## Prepare model

In [49]:
device = torch.device("cuda")
model.to(device)
print ("Done.")

Done.


In [50]:
%%capture
computed = []

model.eval()
model.to(device)

## Prepare tokendict

In [51]:
tokendict_list = [{"A": [], "G": [], "T": [],"C": []} for x in range(6)]

for tpl in itertools.product("ACGT",repeat=6):
    encoding = tokenizer.encode("".join(tpl))
    for idx, nuc in enumerate(tpl):
        tokendict_list[idx][nuc].append(encoding[1])

# Run Inference

In [52]:
tokenized_data = next(iter(data_loader))

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [53]:
model_input_unaltered = tokenized_data['input_ids'].clone()

inputs = model_input_unaltered.clone().to(device) 
targets_masked = model_input_unaltered.clone().to(device)
    
with torch.no_grad():
    model_outs = model(inputs,labels=targets_masked)

model_pred = torch.softmax(model_outs['logits'], dim=2).cpu().numpy()
loss.append(model_outs['loss'].item())

NameError: name 'loss' is not defined

In [None]:
id_to_token = {v:k for k,v in tokenizer.vocab.items()}

In [None]:
id_seq = inputs[0].cpu().numpy()

In [147]:
seq_batch_idx = 0
seq_decoded = [id_to_token[token_id] for token_id in id_seq]
seq = dataset.seq_chunked.iloc[0]

seq_probas = []

for pos_in_seq in range(len(seq)):
    span_idx = slice(1+max(pos_in_seq-5,0),1+pos_in_seq+1)
    span = seq_decoded[span_idx]
    pos_in_token = range(len(span)-1,-1,-1)
    span_pred = model_pred[seq_batch_idx][span_idx]
    print(seq[pos_in_seq],pos_in_seq,span,pos_in_token,)
    position_probas = []
    for nuc in 'ACGT':
        mut_span = [token[:pos]+nuc+token[pos+1:] for token,pos in zip(span,pos_in_token)]
        mut_span_ids = [tokenizer.vocab[tok] for tok in mut_span]
        position_probas.append(np.mean([span_pred[idx][mut_span_ids[idx]] for idx in range(len(span))]))
    seq_probas.append(position_probas)
    if pos_in_seq==9:
        break   

G 0 ['GCAAAA'] range(0, -1, -1)
C 1 ['GCAAAA', 'CAAAAA'] range(1, -1, -1)
A 2 ['GCAAAA', 'CAAAAA', 'AAAAAT'] range(2, -1, -1)
A 3 ['GCAAAA', 'CAAAAA', 'AAAAAT', 'AAAATA'] range(3, -1, -1)
A 4 ['GCAAAA', 'CAAAAA', 'AAAAAT', 'AAAATA', 'AAATAG'] range(4, -1, -1)
A 5 ['GCAAAA', 'CAAAAA', 'AAAAAT', 'AAAATA', 'AAATAG', 'AATAGT'] range(5, -1, -1)
A 6 ['CAAAAA', 'AAAAAT', 'AAAATA', 'AAATAG', 'AATAGT', 'ATAGTG'] range(5, -1, -1)
T 7 ['AAAAAT', 'AAAATA', 'AAATAG', 'AATAGT', 'ATAGTG', 'TAGTGT'] range(5, -1, -1)
A 8 ['AAAATA', 'AAATAG', 'AATAGT', 'ATAGTG', 'TAGTGT', 'AGTGTG'] range(5, -1, -1)
G 9 ['AAATAG', 'AATAGT', 'ATAGTG', 'TAGTGT', 'AGTGTG', 'GTGTGA'] range(5, -1, -1)


In [148]:
unmap_dict = {0:'A',1:'C',2:'G',3:'T'}
is_correct=[unmap_dict[i]==seq[idx] for idx,i in enumerate(np.argmax(seq_probas,1))]
acc = np.mean(is_correct)
acc

1.0

In [161]:
seq = ''.join(np.random.choice(['A','C','G','T'],size=7))
seq

'CCTATCA'

In [163]:


max_len = 10

left = (len(seq)-max_len)//2

if left>=0:
    centered_seq = seq[left:left+max_len]
else:
    left = -left
    right = max_len-len(seq)-left
    centered_seq = 'N'*left + seq + 'N'*right

centered_seq

'NNCCTATCAN'

In [153]:
'N'*10

'NNNNNNNNNN'

In [95]:
seq[:10]

'ACCACACCAG'

In [51]:
dataset.seq_chunked.iloc[0]

'ACCACACCAGTGAACAAGAAGGAAAGGCCTTTTGGAGTGTGTCTTTCTGTGTGTTTAAAAACAGTGGGAAAATCCACACCACACTCTAAGTGGACAGCTCAGAATAATTTAGCATATTTCCTTTCACACTTAAAGTTCTTAAGATGAGACTGTGTAAATGAGAGAAAGACTTGATTCAGAGGAAAAAAATGTGTTCTCTCTGGACTCTTGCCTAGCTCACAAGGCTTTGGAGAATTGTCTTCCTAAATCGGGGCACCTGCTACAGAAGAGAATGTTTCGTTCCCTCTGGGTATGCACAGCTAAGGGGCCTCATTTCTCCCAGAGGGAGCTGCTGGCCCGCTCTAAGGGGTGCAAGAGGAAGGAGTGTGCCCAGACCAGCTGGGGGAGCATTGAAAGGACAGCTTGGGGATGGAATTTTTTTTTTTTTATCTCCATTTTCCAGTTAGGTAGAGCCAAGCTGAGGAATCCTTTTAAGATTATTAGAAAACATGCCCTGAAGAGAGTTGTTCT'

In [217]:
k = 6
predicted_prbs,gts,is_correct,loss = [],[],[],[]
#print (dataset.iloc[0]['seq_chunked'])

for no_of_index, tokenized_data in tqdm.tqdm(enumerate(data_loader)):
    #if no_of_index < 1340:
    #    continue
    label = dataset.iloc[no_of_index]['seq_chunked']
    label_len = len(label)
    #print(no_of_index, label_len)
    
    # Edge case: for a sequence less then 11 nt
    # we cannot even feed 6 mask tokens
    # so we might as well predict random
    if label_len < 20: 
        #print (no_of_index)
        for i in range(label_len):
            predicted_prbs.append(torch.tensor([0.25,0.25,0.25,0.25,0.0]))
            gts.append(label[i])
            is_correct.append(res.argmax().item() == nuc_dict.get(gt,4))
        continue

        
    model_input_unaltered = tokenized_data['input_ids'].clone()
    tokenized_data['labels'][tokenized_data['labels']==-100] = 0
    inputs = model_input_unaltered.clone()
    

    # First 5 nucleotides we infer from the first 6-mer
    if MASKING:
        inputs[:, 1:7] = 4 # we mask the first 6 6-mers
    inputs = inputs.to(device) 

    targets_masked = model_input_unaltered.clone().to(device)
    targets_masked[inputs!=4] = -100
    
    with torch.no_grad():
        model_outs = model(inputs,labels=targets_masked)
        
    model_pred = torch.softmax(model_outs['logits'], dim=2)
    loss.append(model_outs['loss'].item())
    
    for i in range(5):
        res,gt = extract_prbs_from_pred(prediction=model_pred[0],
                                        pred_pos=1, # first 6-mer (after CLS)
                                        token_pos=i, # we go thorugh first 6-mer
                                        label_pos=i,
                                        label=label)
        predicted_prbs.append(res)
        gts.append(gt)
        is_correct.append(res.argmax() == nuc_dict.get(gt,4))

    
    
    # we do a batched predict to process the rest of the sequence
    predictions,seq_loss = predict_on_batch(tokenized_data, dataset, no_of_index)
    if seq_loss is not None:
        loss.append(seq_loss.item())
    
    # For the 6th nt up to the last 5 
    # we extract probabilities similar to how the model was trained
    # hiding the 4th nt of the 3rd masked 6-mer of a span of 6 masked 6-mers
    # note that CLS makes the tokenized seq one-based
    pos = 5 # position in sequence
    for pos in range(5, label_len-5):
        model_pred = predictions[pos-5]
        res,gt = extract_prbs_from_pred(prediction=model_pred,
                                        pred_pos=pos-2, # for i-th nt, we look at (i-2)th 6-mer
                                        token_pos=3, # look at 4th nt in 6-mer
                                        label_pos=pos,
                                        label=label)

        if pos==label_len//2:
            print(res)
            
        predicted_prbs.append(res)
        gts.append(gt)
        is_correct.append(res.argmax() == nuc_dict.get(gt,4))

    # Infer the last 5 nt from the last 6-mer
    for i in range(5):
        model_pred = predictions[pos-5]
        res,gt = extract_prbs_from_pred(prediction=model_pred,
                                pred_pos=pos+1, # len - 5 + 1 = last 6-mer (1-based)
                                token_pos=i+1, # we go through last 5 of last 6-mer
                                label_pos=pos+i,
                                label=label)
        predicted_prbs.append(res)
        gts.append(gt)
        is_correct.append(res.argmax() == nuc_dict.get(gt,4))
    assert(len(gts) == torch.stack(predicted_prbs).shape[0]), "{} iter, expected len:{} vs actual len:{}".format(no_of_index,
                                                                                   len(gts), 
                                                                     torch.stack(predicted_prbs).shape[0])
    print(f'chunks:{no_of_index+1}, acc:{np.mean(is_correct):.3f}, loss:{np.mean(loss):.3f}')
    #XABCDEFGHIJKL -> XABCDE [ABCDEF BCDEFG CDEFGH DEFGHI EFGHIJ FGHIJK] GHIJKL

1it [00:06,  6.81s/it]

tensor([0.2278, 0.1440, 0.1964, 0.4317, 0.0000])
chunks:1, acc:0.294, loss:3.493


2it [00:13,  6.82s/it]

tensor([0.4078, 0.1677, 0.2584, 0.1661, 0.0000])
chunks:2, acc:0.292, loss:3.369


3it [00:20,  6.81s/it]

tensor([0.3426, 0.2092, 0.0811, 0.3671, 0.0000])
chunks:3, acc:0.287, loss:3.240


4it [00:27,  6.83s/it]

tensor([0.4620, 0.1299, 0.1765, 0.2315, 0.0000])
chunks:4, acc:0.286, loss:3.133


5it [00:34,  6.83s/it]

tensor([0.3929, 0.2165, 0.1883, 0.2024, 0.0000])
chunks:5, acc:0.280, loss:3.091


5it [00:40,  8.13s/it]


KeyboardInterrupt: 

In [214]:
def predict_on_batch(tokenized_data, dataset, seq_idx, crop_mask_left=None, crop_mask_right=None):
    model_input_unaltered = tokenized_data['input_ids'].clone()
    label = dataset.iloc[seq_idx]['seq_chunked']
    label_len = len(label)
    if label_len < 6:
        return torch.ones(label_len,label_len,5)*0.25, None
    else:
        diag_matrix = torch.eye(tokenized_data['input_ids'].shape[1]).numpy()
        masked_indices = np.apply_along_axis(lambda m : np.convolve(m, [1] * 6, mode = 'same' ),axis = 1, arr = diag_matrix).astype(bool)
        masked_indices = torch.from_numpy(masked_indices)
        masked_indices = masked_indices[3:label_len-5-2]
        if crop_mask_left and crop_mask_right:
                masked_indices = masked_indices[crop_mask_left:crop_mask_right]
        res = tokenized_data['input_ids'].expand(masked_indices.shape[0],-1).clone()
        targets_masked = res.clone().to(device)
        res[masked_indices] = 4
        targets_masked[res!=4] = -100
        #print (res[0], res.shape)
        res = res.to(device)
        with torch.no_grad():
            model_outs = model(res,labels=targets_masked)
            fin_calculation = torch.softmax(model_outs['logits'], dim=2).detach().cpu()
        return fin_calculation, model_outs['loss']

In [221]:
central_window =5

k = 6
predicted_prbs,gts,is_correct = [],[],[]

for no_of_index, tokenized_data in tqdm.tqdm(enumerate(data_loader)):

    label = dataset.iloc[no_of_index]['seq_chunked']
    L = len(label)

    if L < 20: 
        #print (no_of_index)
        for i in range(L):
            predicted_prbs.append(torch.tensor([0.25,0.25,0.25,0.25,0.25]))
            gts.append(label[i])
        continue
        
    central_pos = L//2
    central_token_idx = central_pos-5 #first token, where the central nucleotide appears
    token_idx_left,token_idx_right = central_token_idx,central_token_idx+central_window+6-1
    pos_left,pos_right = central_pos, central_pos+central_window

    predictions,seq_loss = predict_on_batch(tokenized_data, dataset, no_of_index,
                                           crop_mask_left=token_idx_left,crop_mask_right=token_idx_right)

    for pos in range(pos_left):
        predicted_prbs.append(torch.tensor([0.25,0.25,0.25,0.25,0.25]))
        gts.append(label[pos])
        
    for pos in range(pos_left,pos_right):
            model_pred = predictions[pos-pos_left]
            res,gt = extract_prbs_from_pred(prediction=model_pred,
                                            pred_pos=pos-2, # for i-th nt, we look at (i-2)th 6-mer
                                            token_pos=3, # look at 4th nt in 6-mer
                                            label_pos=pos,
                                             label=label)                
            if pos==L//2:
                print(res)
            predicted_prbs.append(res)
            gts.append(gt)
            is_correct.append(res.argmax() == nuc_dict.get(gt,4))
    
    for pos in range(pos_right,L):
        predicted_prbs.append(torch.tensor([0.25,0.25,0.25,0.25,0.25]))
        gts.append(label[pos])

    if no_of_index%100==0:
        print(f'chunks:{no_of_index+1}, acc:{np.mean(is_correct):.3f}')
    

2it [00:00,  6.54it/s]

tensor([0.2278, 0.1440, 0.1964, 0.4317, 0.0000])
chunks:1, acc:0.600
tensor([0.4078, 0.1677, 0.2584, 0.1661, 0.0000])


4it [00:00,  7.11it/s]

tensor([0.3426, 0.2092, 0.0811, 0.3671, 0.0000])
tensor([0.4620, 0.1299, 0.1765, 0.2315, 0.0000])


6it [00:00,  7.35it/s]

tensor([0.3929, 0.2165, 0.1883, 0.2024, 0.0000])
tensor([0.5619, 0.1091, 0.1305, 0.1985, 0.0000])


8it [00:01,  7.40it/s]

tensor([0.3746, 0.0728, 0.1724, 0.3802, 0.0000])
tensor([0.3086, 0.1243, 0.1752, 0.3919, 0.0000])


10it [00:01,  7.49it/s]

tensor([0.4068, 0.1356, 0.1335, 0.3241, 0.0000])
tensor([0.3109, 0.1898, 0.3150, 0.1843, 0.0000])


13it [00:01,  8.59it/s]

tensor([0.3994, 0.1980, 0.1777, 0.2249, 0.0000])
tensor([0.3148, 0.1564, 0.1682, 0.3606, 0.0000])
tensor([0.4001, 0.1890, 0.2355, 0.1754, 0.0000])


15it [00:01,  7.95it/s]

tensor([0.2574, 0.0989, 0.1350, 0.5087, 0.0000])
tensor([0.3014, 0.0859, 0.1665, 0.4462, 0.0000])


17it [00:02,  7.66it/s]

tensor([0.3239, 0.2411, 0.1073, 0.3277, 0.0000])
tensor([0.3814, 0.1694, 0.1615, 0.2876, 0.0000])


19it [00:02,  7.56it/s]

tensor([0.3899, 0.2206, 0.1039, 0.2855, 0.0000])
tensor([0.2034, 0.1224, 0.2647, 0.4095, 0.0000])


22it [00:02, 10.23it/s]

tensor([0.1946, 0.2185, 0.2926, 0.2944, 0.0000])
tensor([0.3070, 0.1969, 0.1983, 0.2978, 0.0000])
tensor([0.1478, 0.2628, 0.3095, 0.2799, 0.0000])
tensor([0.2770, 0.2449, 0.1380, 0.3401, 0.0000])


24it [00:02, 10.26it/s]

tensor([0.2930, 0.3015, 0.0708, 0.3348, 0.0000])
tensor([0.2459, 0.1081, 0.2468, 0.3992, 0.0000])


27it [00:03,  8.63it/s]

tensor([0.1350, 0.2423, 0.1678, 0.4550, 0.0000])
tensor([0.2691, 0.2508, 0.2146, 0.2655, 0.0000])


29it [00:03,  8.06it/s]

tensor([0.5871, 0.1030, 0.1449, 0.1650, 0.0000])
tensor([0.3560, 0.1648, 0.1598, 0.3194, 0.0000])


32it [00:03,  9.61it/s]

tensor([0.3616, 0.0876, 0.1929, 0.3579, 0.0000])
tensor([0.2667, 0.2315, 0.1466, 0.3552, 0.0000])
tensor([0.2394, 0.2197, 0.2845, 0.2564, 0.0000])


34it [00:04,  8.47it/s]

tensor([0.1977, 0.1769, 0.2561, 0.3693, 0.0000])
tensor([0.5236, 0.1542, 0.1440, 0.1781, 0.0000])


36it [00:04,  8.03it/s]

tensor([0.1961, 0.2368, 0.2363, 0.3307, 0.0000])
tensor([0.2626, 0.1760, 0.3978, 0.1637, 0.0000])


38it [00:04,  7.71it/s]

tensor([0.2358, 0.3202, 0.1840, 0.2600, 0.0000])
tensor([0.2969, 0.2292, 0.1631, 0.3107, 0.0000])


40it [00:04,  7.67it/s]

tensor([0.2034, 0.2486, 0.3277, 0.2203, 0.0000])
tensor([0.2872, 0.2571, 0.1601, 0.2956, 0.0000])


40it [00:05,  7.81it/s]


KeyboardInterrupt: 

In [208]:
L

510

In [181]:
pos

1

In [168]:
L

13

In [60]:
    model_input_unaltered = tokenized_data['input_ids'].clone()
    label = dataset.iloc[no_of_index]['seq_chunked']
    label_len = len(label)
    if label_len < 6:
        pass
    #    return torch.zeros(label_len,label_len,5), None
    else:
        if MASKING:
            diag_matrix = torch.eye(tokenized_data['input_ids'].shape[1]).numpy()
            masked_indices = np.apply_along_axis(lambda m : np.convolve(m, [1] * 6, mode = 'same' ),axis = 1, arr = diag_matrix).astype(bool)
            masked_indices = torch.from_numpy(masked_indices)
            #masked_indices = masked_indices[3:label_len-5-2]
            masked_indices = masked_indices[]
            res = tokenized_data['input_ids'].expand(masked_indices.shape[0],-1).clone()
            targets_masked = res.clone().to(device)
            res[masked_indices] = 4
            targets_masked[res!=4] = -100
            #print (res[0], res.shape)
        res = res.to(device)
        with torch.no_grad():
            model_outs = model(res,labels=targets_masked)
            fin_calculation = torch.softmax(model_outs['logits'], dim=2).detach().cpu()   
        #return fin_calculation, model_outs['loss']

In [102]:
res[0]

tensor([   2,    4,    4,    4,    4,    4,    4, 1130,  410, 1628, 2403, 1406,
        1515, 1952, 3698, 2490, 1753, 2902, 3401, 1301, 1094,  267, 1054,  105,
         406, 1609, 2326, 1099,  286, 1130,  410, 1626, 2394, 1369, 1367, 1357,
        1319, 1165,  550, 2186,  537, 2136,  339, 1344, 1265,  951, 3789, 2856,
        3219,  574, 2281,  920, 3665, 2358, 1226,  794, 3161,  343, 1358, 1321,
        1173,  582, 2314, 1051,   95,  365, 1446, 1674, 2587, 2142,  361, 1430,
        1610, 2329, 1110,  330, 1305, 1111,  333, 1319, 1166,  556, 2209,  629,
        2503, 1807, 3117,  166,  650, 2586, 2139,  349, 1384, 1426, 1596, 2275,
         894, 3562, 1946, 3674, 2395, 1376, 1395, 1472, 1777, 3000, 3795, 2879,
        3309,  933, 3720, 2579, 2112,  242,  956, 3810, 2938, 3545, 1877, 3400,
        1298, 1081,  215,  845, 3365, 1158,  523, 2077,  102,  393, 1558, 2121,
         277, 1095,  272, 1073,  181,  712, 2835, 3136,  242,  954, 3802, 2908,
        3425, 1397, 1480, 1809, 3126,  2

In [None]:
masked_indices[0]

In [None]:
predicted_prbs = np.array(predicted_prbs)[:,:4]

In [None]:
dataset = dataset[['seq_name','original_seq']].drop_duplicates()
dataset['seq_len'] = dataset.original_seq.apply(len)

all_preds = []

s = 0

for seq_name, original_seq, seq_len in dataset.values.tolist():
    seq_probas = predicted_prbs[s:s+seq_len,:]
    s += seq_len
    if strand_info[seq_name]=='-' and dna_model:
        seq_probas = seq_probas[::-1,[3,2,1,0]] #reverse complement probabilities s.t. probas match original_seq
    all_preds.append((seq_name,original_seq, seq_probas))

with open(output_dir + f"/predictions_{fold}.pickle", "wb") as f:
    seq_names, seqs, probs = zip(*all_preds)
    pickle.dump({'seq_names':seq_names, 'seqs':seqs, 'probs':probs, 'fasta':fasta},f)

## Statistics

In [None]:
np.mean(np.max(prbs_arr,axis=1))

In [None]:
# accuracy
np.sum(gts == np.array(["A","C","G","T"])[np.argmax(prbs_arr,axis=1)])/len(gts)

In [None]:
for nt in ["A", "C", "G", "T"]:
    nt_arr = np.array([nt]*len(gts))
    actual = np.sum(gts == nt_arr)/len(gts)
    predicted = np.sum(np.array(["A","C","G","T"])[np.argmax(prbs_arr,axis=1)] == nt_arr)/len(gts)
    print("{}: Actual {}, Predicted {}".format(nt, actual, predicted))

In [None]:
log_prbs = torch.log(torch.stack(predicted_prbs)[:,:-1])
class_labels = torch.tensor(class_label_gts(gts))

In [None]:
torch.nn.functional.nll_loss(log_prbs, class_labels)

# Make data fit metrics handler

In [None]:
#out_path = "outputs/gpar_bertadn/"

In [None]:
# get targets
targets = torch.tensor(class_label_gts(gts))
stacked_prbs = torch.stack(predicted_prbs)

In [None]:
# compute cross entropy, it's already as probability so just nll
ce = torch.nn.functional.nll_loss(stacked_prbs, targets, reduction="none") #cross_entropy(prbs, targets)

#print(ce)

# save
torch.save(stacked_prbs,  out_path+"masked_logits.pt") # no logits, so use prbs
torch.save(torch.argmax(stacked_prbs, dim=1),  out_path+"masked_preds.pt")
torch.save(stacked_prbs,  out_path+"prbs.pt")
torch.save(ce, out_path+"ce.pt")

# save targets
torch.save(targets, out_path+"masked_targets.pt")

# save rest as placeholders (zeros of same length)
torch.save(torch.zeros(len(stacked_prbs)), out_path+"masked_motifs.pt")