In [32]:
import torch 
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict
from torch.utils.data import DataLoader, IterableDataset
from torch.nn.utils.rnn import pad_sequence

import itertools

import os
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
N_FOLDS = 10
MAX_TOK_LEN = 1024
BATCH_SIZE = 64

In [34]:
data_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [35]:
model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/nucleotide-transform/nucleotide-transformer-v2-500m-multi-species'
model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/nucleotide-transform/nucleotide-transformer-v2-250m-multi-species'

model_dir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/models/ntrans-v2-250m-3utr/checkpoints/epoch_31/'

In [36]:
if not '3utr' in model_dir:
    #dna model: model trained on sequences that weren't reverse complemented for genes on negative strand
    #all NT models from InstaDeep
    #the sequences transcribed from the negative strand will be reverse complemented before inference
    #after inference, the probabilities will be reverse complemented again to match the original sequence
    reverse_seq_neg_strand = True
else:
    reverse_seq_neg_strand = False

print(f'Reverse sequences on negative strand before inference: {reverse_seq_neg_strand}')

Reverse sequences on negative strand before inference: False


In [37]:
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_dir,trust_remote_code=True).to(device);

In [7]:
def get_chunks(seq_tokens):
    '''
    Chunk the given token sequence into chunks of MAX_TOK_LEN
    The input sequences shouldn't contain special tokens
    The last chunk is padded with the previous chunk if it's shorter than MAX_TOK_LEN
    '''
    if tokenizer.eos_token_id:
        #in the original InstaDeep models, the cls token wasn't present
        chunk_len = MAX_TOK_LEN-2 #2 special tokens to be added 
    else:
        chunk_len = MAX_TOK_LEN-1 #only cls token
    chunks = [seq_tokens[start:start+chunk_len] for start in range(0,len(seq_tokens),chunk_len)]
    assert [x for y in chunks for x in y]==seq_tokens
    if len(chunks)>1:
        left_shift = min(chunk_len-len(chunks[-1]), len(chunks[-2])) #overlap length for the last chunk and the previous one
        if left_shift>0:
            pad_seq = chunks[-2][-left_shift:]
            chunks[-1] = pad_seq + chunks[-1]
    else:
        left_shift = 0
    if tokenizer.eos_token_id:
        chunks = [[tokenizer.cls_token_id, *chunk, tokenizer.eos_token_id] for chunk in chunks]
        assert [x for y in chunks[:-1] for x in y[1:-1]]+[x for x in  chunks[-1][1+left_shift:-1]]==seq_tokens
    else:
        chunks = [[tokenizer.cls_token_id, *chunk] for chunk in chunks]
        assert [x for y in chunks[:-1] for x in y[1:]]+[x for x in  chunks[-1][1+left_shift:]]==seq_tokens
    #left_shift only makes sense for the last chunk, for the other chunks it's 0
    res = [(chunk,0) if chunk_idx!=len(chunks)-1 else (chunk,left_shift) for chunk_idx, chunk in enumerate(chunks)]
    return res

In [8]:
def mask_sequence(seq_tokens, left_shift=0):
    '''
    Consecutively mask tokens in the sequence and yield each masked position
    Don't mask special tokens and the first left_shift sequence tokens
    '''    
    for mask_pos in range(1+left_shift,len(seq_tokens)):
        if seq_tokens[mask_pos] in (tokenizer.eos_token_id,tokenizer.pad_token_id):
            break
        masked_seq = seq_tokens.clone()
        masked_seq[mask_pos] = tokenizer.mask_token_id
        yield mask_pos, masked_seq

In [9]:
class SeqDataset(IterableDataset):
    
    def __init__(self, seq_df):
        
        self.seq_df = seq_df
        self.start = 0
        self.end = len(self.seq_df)
        
    def __iter__(self):
        
        for seq_idx in range(self.start, self.end):
            
            seq_info = self.seq_df.iloc[seq_idx]
            chunk, left_shift = seq_info.seq
            
            gt_tokens = torch.LongTensor(chunk)
            
            for masked_pos, masked_tokens in mask_sequence(gt_tokens, left_shift):
                #consecutively mask each token in the sequence
                assert masked_tokens[masked_pos] == tokenizer.mask_token_id
                yield seq_info.seq_name, gt_tokens, masked_pos, masked_tokens

def worker_init_fn(worker_id):
     worker_info = torch.utils.data.get_worker_info()
     dataset = worker_info.dataset  # the dataset copy in this worker process
     overall_start = dataset.start
     overall_end = dataset.end
     # configure the dataset to only process the split workload
     per_worker = int(np.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
     worker_id = worker_info.id
     dataset.start = overall_start + worker_id * per_worker
     dataset.end = min(dataset.start + per_worker, overall_end)

In [10]:
def predict_on_batch(masked_tokens_batch):

    targets_masked = masked_tokens_batch.clone()
    targets_masked[targets_masked!=tokenizer.mask_token_id] = -100
    attention_mask = masked_tokens_batch!= tokenizer.pad_token_id   
    
    with torch.no_grad():
        torch_outs = model(
        masked_tokens_batch.to(device),
        labels = targets_masked.to(device),
        attention_mask=attention_mask.to(device),
        encoder_attention_mask=attention_mask.to(device),
        output_hidden_states=False)
    
    logits = torch_outs["logits"] #max_tokenized_length x (max_tokenized_length+1) x N_tokens
    
    probas_batch = F.softmax(logits, dim=-1).cpu().numpy()
    
    loss = torch_outs["loss"].item()
    
    return probas_batch, loss

In [11]:
def reverse_complement(seq):
    '''
    Take sequence reverse complement
    '''
    compl_dict = {'A':'T', 'C':'G', 'G':'C', 'T':'A'}
    compl_seq = ''.join([compl_dict.get(x,x) for x in seq])
    rev_seq = compl_seq[::-1]
    return rev_seq

In [12]:
tokendict_list = [{"A": [], "G": [], "T": [],"C": []} for x in range(6)]

for tpl in itertools.product("ACGT",repeat=6):
    encoding = tokenizer.encode("".join(tpl))
    for idx, nuc in enumerate(tpl):
        tokendict_list[idx][nuc].append(encoding[1]) #token indices for idx position in 6-mer and letter nuc

In [13]:
fasta = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/fasta/Homo_sapiens_rna.fa'

seq_df = defaultdict(str)

with open(fasta, 'r') as f:
    for line in f:
        if line.startswith('>'):
            seq_name = line[1:].rstrip()
        else:
            seq_df[seq_name] += line.rstrip().upper()
            
seq_df = pd.DataFrame(list(seq_df.items()), columns=['seq_name','seq'])

In [14]:
fold = 0

In [15]:
folds = np.arange(N_FOLDS).repeat(len(seq_df)//N_FOLDS+1)[:len(seq_df)] #split into folds 

seq_df = seq_df.loc[folds==fold] #get required fold

print(f'Fold {fold}: {len(seq_df)} sequences')

#reverse complement on the negative strand if reverse_seq_neg_strand=True
strand_info = pd.read_csv(data_dir + 'UTR_coords/GRCh38_3_prime_UTR_clean-sorted.bed', sep='\t', header = None, names=['seq_name','strand'], usecols=[3,5]).set_index('seq_name').squeeze()
seq_df.seq = seq_df.apply(lambda x: reverse_complement(x.seq) if strand_info.loc[x.seq_name]=='-' 
                          and reverse_seq_neg_strand else x.seq, axis=1) #undo reverse complement

original_seqs = seq_df.set_index('seq_name').seq #sequences before tokenization

seq_df['seq'] = seq_df.seq.apply(lambda x: tokenizer(x,add_special_tokens=False)['input_ids']) #tokenize sequences
seq_df['seq'] = seq_df.seq.apply(lambda x:get_chunks(x)) #chunk sequences

seq_df = seq_df.explode('seq') #each chunk gets a single row

Token indices sequence length is longer than the specified maximum sequence length for this model (1078 > 1024). Running this sequence through the model will result in indexing errors


Fold 0: 1814 sequences


In [16]:
#N_tokens = seq_df.seq.apply(lambda x:len(x[0])-2-x[1]).sum() #minus special tokens minus overlap for last chunk

In [17]:
def collate_fn(batch):
    '''
    Collate tokenized sequences based on the maximal sequence length in the batch
    '''
    seq_names_batch, gt_tokens_batch, masked_pos_batch, masked_tokens_batch = zip(*batch)
    masked_tokens_batch = pad_sequence(masked_tokens_batch, batch_first=True, padding_value=tokenizer.pad_token_id)
    gt_tokens_batch = pad_sequence(gt_tokens_batch, batch_first=True, padding_value=tokenizer.pad_token_id)
    return seq_names_batch, gt_tokens_batch, masked_pos_batch, masked_tokens_batch

In [18]:
dataloader = DataLoader(SeqDataset(seq_df), batch_size=BATCH_SIZE, 
                        shuffle=False, collate_fn=collate_fn, 
                        num_workers=1, worker_init_fn=worker_init_fn)

In [19]:
nuc_dict = {"A":0,"C":1,"G":2,"T":3} #for accuracy

all_probas = defaultdict(list) #probas for all masked tokens 
verif_seqs = defaultdict(str) #reconstruct sequences from mask tokens and make sure that they match the original sequences

all_losses, is_correct = [], [] #to compute loss and accuracy
prev_seq_name = None #name of the previous sequence

pbar = tqdm(total=len(original_seqs))

for seq_names_batch, gt_tokens_batch, masked_pos_batch, masked_tokens_batch in dataloader:

    probas_batch, loss_batch = predict_on_batch(masked_tokens_batch)
    #probas_batch, loss_batch = np.zeros((len(seq_names_batch),1024,4108)), 0 #placeholder for testing
    
    all_losses.append(loss_batch)
    
    for seq_name, gt_tokens, masked_pos, seq_probas in zip(seq_names_batch, gt_tokens_batch, masked_pos_batch, probas_batch):
        gt_token = tokenizer.id_to_token(gt_tokens[masked_pos].item()) #ground truth masked token
        for idx in range(len(gt_token)):
            #loop over all positions of the masked token
            position_probas = [] #probabilities for all bases at given position
            for nuc in 'ACGT':
                position_probas.append(seq_probas[masked_pos][tokendict_list[idx][nuc]].sum()) #sum over all tokens that have given letter at given position
            all_probas[seq_name].append(position_probas)
        if seq_name!=prev_seq_name:
            #processing of prev_seq_name is completed
            if len(verif_seqs[prev_seq_name])>0:
                is_correct.extend([nuc_dict.get(base,4)==gt_idx for base, gt_idx in zip(verif_seqs[prev_seq_name],np.argmax(all_probas[prev_seq_name],axis=1))])
                print(f'Sequence {prev_seq_name} processed ({len(verif_seqs)-1}/{len(original_seqs)}), loss: {np.mean(all_losses):.3}, acc:{np.mean(is_correct):.3}')
                assert verif_seqs[prev_seq_name]==original_seqs.loc[prev_seq_name] #compare reconstruction from the masked token with the original sequence
                pbar.update(1)
            prev_seq_name = seq_name
        verif_seqs[seq_name] += gt_token      

assert verif_seqs[seq_name]==original_seqs.loc[seq_name]

  0%|          | 1/1814 [00:13<6:42:47, 13.33s/it]

Sequence ENST00000641515.2_utr3_2_0_chr1_70009_f processed (1/1814), loss: 22.7, acc:0.375
Sequence ENST00000616016.5_utr3_13_0_chr1_944154_f processed (2/1814), loss: 22.7, acc:0.379


  0%|          | 3/1814 [00:14<1:56:07,  3.85s/it]

Sequence ENST00000327044.7_utr3_18_0_chr1_944203_r processed (3/1814), loss: 22.3, acc:0.378


  0%|          | 4/1814 [00:16<1:40:33,  3.33s/it]

Sequence ENST00000338591.8_utr3_11_0_chr1_965192_f processed (4/1814), loss: 22.3, acc:0.38


  0%|          | 5/1814 [00:19<1:32:02,  3.05s/it]

Sequence ENST00000379410.8_utr3_15_0_chr1_974576_f processed (5/1814), loss: 22.3, acc:0.39


  0%|          | 6/1814 [00:21<1:20:56,  2.69s/it]

Sequence ENST00000433179.4_utr3_3_0_chr1_975198_r processed (6/1814), loss: 22.4, acc:0.396


  0%|          | 7/1814 [00:23<1:14:48,  2.48s/it]

Sequence ENST00000304952.11_utr3_3_0_chr1_998964_r processed (7/1814), loss: 22.3, acc:0.397
Sequence ENST00000649529.1_utr3_1_0_chr1_1014479_f processed (8/1814), loss: 22.3, acc:0.398


  0%|          | 9/1814 [00:25<55:07,  1.83s/it]  

Sequence ENST00000379370.7_utr3_35_0_chr1_1054982_f processed (9/1814), loss: 22.3, acc:0.4


  1%|          | 10/1814 [00:27<56:45,  1.89s/it]

Sequence ENST00000453464.3_utr3_1_0_chr1_1070967_r processed (10/1814), loss: 22.1, acc:0.416


  1%|          | 11/1814 [00:31<1:13:31,  2.45s/it]

Sequence ENST00000421241.7_utr3_9_0_chr1_1081823_r processed (11/1814), loss: 22.1, acc:0.41
Sequence ENST00000379289.6_utr3_15_0_chr1_1197848_f processed (12/1814), loss: 22.1, acc:0.411
Sequence ENST00000379268.7_utr3_4_0_chr1_1203508_r processed (13/1814), loss: 22.1, acc:0.413


  1%|          | 14/1814 [00:34<52:25,  1.75s/it]  

Sequence ENST00000379236.4_utr3_6_0_chr1_1211340_r processed (14/1814), loss: 21.8, acc:0.414
Sequence ENST00000360001.12_utr3_6_0_chr1_1216931_r processed (15/1814), loss: 21.8, acc:0.413


  1%|          | 16/1814 [00:45<1:26:35,  2.89s/it]

Sequence ENST00000379198.5_utr3_0_0_chr1_1233269_f processed (16/1814), loss: 21.6, acc:0.403
Sequence ENST00000330388.2_utr3_7_0_chr1_1242453_r processed (17/1814), loss: 21.6, acc:0.403


  1%|          | 18/1814 [00:49<1:21:55,  2.74s/it]

Sequence ENST00000349431.11_utr3_6_0_chr1_1253912_r processed (18/1814), loss: 21.5, acc:0.4
Sequence ENST00000379116.10_utr3_17_0_chr1_1291611_f processed (19/1814), loss: 21.5, acc:0.4


  1%|          | 20/1814 [00:54<1:17:04,  2.58s/it]

Sequence ENST00000354700.10_utr3_23_0_chr1_1292391_r processed (20/1814), loss: 21.6, acc:0.4
Sequence ENST00000379031.10_utr3_7_0_chr1_1311380_f processed (21/1814), loss: 21.6, acc:0.399
Sequence ENST00000435064.6_utr3_16_0_chr1_1311600_r processed (22/1814), loss: 21.6, acc:0.399


  1%|▏         | 23/1814 [00:58<1:03:25,  2.12s/it]

Sequence ENST00000343938.9_utr3_2_0_chr1_1327764_f processed (23/1814), loss: 21.8, acc:0.398


  1%|▏         | 24/1814 [01:00<1:01:15,  2.05s/it]

Sequence ENST00000339381.6_utr3_5_0_chr1_1334465_f processed (24/1814), loss: 21.8, acc:0.398


  1%|▏         | 25/1814 [01:02<59:39,  2.00s/it]  

Sequence ENST00000378888.10_utr3_14_0_chr1_1335278_r processed (25/1814), loss: 21.8, acc:0.398


  1%|▏         | 26/1814 [01:04<58:13,  1.95s/it]

Sequence ENST00000309212.11_utr3_9_0_chr1_1352691_r processed (26/1814), loss: 21.9, acc:0.397


  1%|▏         | 27/1814 [01:06<1:04:25,  2.16s/it]

Sequence ENST00000338338.10_utr3_3_0_chr1_1373736_r processed (27/1814), loss: 21.9, acc:0.397


  2%|▏         | 28/1814 [01:12<1:32:14,  3.10s/it]

Sequence ENST00000400809.8_utr3_10_0_chr1_1385711_r processed (28/1814), loss: 21.9, acc:0.392
Sequence ENST00000344843.12_utr3_3_0_chr1_1401909_r processed (29/1814), loss: 21.9, acc:0.392


  2%|▏         | 30/1814 [01:17<1:24:01,  2.83s/it]

Sequence ENST00000537107.6_utr3_3_0_chr1_1418420_r processed (30/1814), loss: 21.8, acc:0.391


  2%|▏         | 31/1814 [01:33<2:53:56,  5.85s/it]

Sequence ENST00000378821.4_utr3_1_0_chr1_1427788_f processed (31/1814), loss: 21.7, acc:0.411


  2%|▏         | 32/1814 [01:58<5:15:34, 10.63s/it]

Sequence ENST00000476993.2_utr3_2_0_chr1_1439788_f processed (32/1814), loss: 21.7, acc:0.406


  2%|▏         | 33/1814 [02:06<4:47:44,  9.69s/it]

Sequence ENST00000378785.7_utr3_11_0_chr1_1468531_f processed (33/1814), loss: 21.7, acc:0.413


  2%|▏         | 34/1814 [02:17<5:01:55, 10.18s/it]

Sequence ENST00000673477.1_utr3_15_0_chr1_1495818_f processed (34/1814), loss: 21.7, acc:0.417


  2%|▏         | 36/1814 [02:18<2:43:52,  5.53s/it]

Sequence ENST00000378756.8_utr3_15_0_chr1_1534073_f processed (35/1814), loss: 21.6, acc:0.417
Sequence ENST00000378733.9_utr3_3_0_chr1_1534778_r processed (36/1814), loss: 21.6, acc:0.417


  2%|▏         | 37/1814 [02:21<2:17:28,  4.64s/it]

Sequence ENST00000291386.4_utr3_4_0_chr1_1541673_r processed (37/1814), loss: 21.6, acc:0.417


  2%|▏         | 38/1814 [02:26<2:20:23,  4.74s/it]

Sequence ENST00000422725.4_utr3_0_0_chr1_1598012_r processed (38/1814), loss: 21.6, acc:0.416
Sequence ENST00000355826.10_utr3_19_0_chr1_1630531_f processed (39/1814), loss: 21.6, acc:0.416
Sequence ENST00000356026.10_utr3_7_0_chr1_1634626_f processed (40/1814), loss: 21.6, acc:0.416


  2%|▏         | 41/1814 [02:35<1:52:26,  3.80s/it]

Sequence ENST00000341832.11_utr3_19_0_chr1_1635225_r processed (41/1814), loss: 21.6, acc:0.416


  2%|▏         | 42/1814 [03:20<6:08:26, 12.48s/it]

Sequence ENST00000617444.5_utr3_9_0_chr1_1661478_r processed (42/1814), loss: 21.6, acc:0.414


  2%|▏         | 43/1814 [03:23<5:05:31, 10.35s/it]

Sequence ENST00000404249.8_utr3_19_0_chr1_1702379_r processed (43/1814), loss: 21.6, acc:0.413


  2%|▏         | 44/1814 [03:30<4:37:32,  9.41s/it]

Sequence ENST00000341426.9_utr3_11_0_chr1_1751232_r processed (44/1814), loss: 21.6, acc:0.411
Sequence ENST00000307786.8_utr3_5_0_chr1_1917194_f processed (45/1814), loss: 21.6, acc:0.411


  3%|▎         | 46/1814 [03:31<2:48:05,  5.70s/it]

Sequence ENST00000310991.8_utr3_4_0_chr1_1917591_r processed (46/1814), loss: 21.6, acc:0.411
Sequence ENST00000682832.2_utr3_38_0_chr1_1921957_r processed (47/1814), loss: 21.6, acc:0.411


  3%|▎         | 48/1814 [03:38<2:23:56,  4.89s/it]

Sequence ENST00000378585.7_utr3_8_0_chr1_2030283_f processed (48/1814), loss: 21.6, acc:0.411
Sequence ENST00000378567.8_utr3_17_0_chr1_2185010_f processed (49/1814), loss: 21.6, acc:0.41
Sequence ENST00000378546.9_utr3_3_0_chr1_2189548_r processed (50/1814), loss: 21.6, acc:0.41


  3%|▎         | 51/1814 [04:13<3:52:17,  7.91s/it]

Sequence ENST00000378536.5_utr3_6_0_chr1_2306766_f processed (51/1814), loss: 21.5, acc:0.409
Sequence ENST00000378531.8_utr3_13_0_chr1_2321253_r processed (52/1814), loss: 21.5, acc:0.409


  3%|▎         | 53/1814 [04:27<3:42:55,  7.60s/it]

Sequence ENST00000605895.6_utr3_6_0_chr1_2403125_f processed (53/1814), loss: 21.4, acc:0.406


  3%|▎         | 54/1814 [04:34<3:39:28,  7.48s/it]

Sequence ENST00000447513.7_utr3_5_0_chr1_2403974_r processed (54/1814), loss: 21.4, acc:0.404


  3%|▎         | 55/1814 [04:35<3:03:13,  6.25s/it]

Sequence ENST00000378486.8_utr3_21_0_chr1_2505214_f processed (55/1814), loss: 21.4, acc:0.404
Sequence ENST00000378466.9_utr3_18_0_chr1_2508537_r processed (56/1814), loss: 21.4, acc:0.404


  3%|▎         | 57/1814 [04:36<2:04:07,  4.24s/it]

Sequence ENST00000378453.4_utr3_2_0_chr1_2528745_r processed (57/1814), loss: 21.4, acc:0.403


  3%|▎         | 58/1814 [04:40<2:01:40,  4.16s/it]

Sequence ENST00000355716.5_utr3_7_0_chr1_2563274_f processed (58/1814), loss: 21.4, acc:0.403


  3%|▎         | 60/1814 [04:58<2:37:50,  5.40s/it]

Sequence ENST00000419916.8_utr3_6_0_chr1_2589428_f processed (59/1814), loss: 21.5, acc:0.401
Sequence ENST00000378412.8_utr3_23_0_chr1_2590639_r processed (60/1814), loss: 21.5, acc:0.401


  3%|▎         | 61/1814 [05:54<9:06:25, 18.70s/it]

Sequence ENST00000401095.9_utr3_8_0_chr1_2636986_r processed (61/1814), loss: 21.6, acc:0.398
Sequence ENST00000378404.4_utr3_0_0_chr1_3022821_f processed (62/1814), loss: 21.6, acc:0.397


  3%|▎         | 63/1814 [06:57<11:43:21, 24.10s/it]

Sequence ENST00000270722.10_utr3_16_0_chr1_3433812_f processed (63/1814), loss: 21.5, acc:0.395


  4%|▎         | 64/1814 [07:02<9:33:19, 19.66s/it] 

Sequence ENST00000378378.9_utr3_14_0_chr1_3480588_f processed (64/1814), loss: 21.5, acc:0.395


  4%|▎         | 65/1814 [07:17<9:01:11, 18.57s/it]

Sequence ENST00000356575.9_utr3_36_0_chr1_3487951_r processed (65/1814), loss: 21.6, acc:0.394


  4%|▎         | 66/1814 [07:23<7:21:58, 15.17s/it]

Sequence ENST00000378344.7_utr3_4_0_chr1_3628604_f processed (66/1814), loss: 21.6, acc:0.395


  4%|▎         | 67/1814 [07:29<6:10:22, 12.72s/it]

Sequence ENST00000270708.12_utr3_11_0_chr1_3630770_r processed (67/1814), loss: 21.6, acc:0.395


  4%|▎         | 68/1814 [07:54<7:52:38, 16.24s/it]

Sequence ENST00000378295.9_utr3_13_0_chr1_3733080_f processed (68/1814), loss: 21.6, acc:0.393
Sequence ENST00000294600.7_utr3_11_0_chr1_3771524_f processed (69/1814), loss: 21.6, acc:0.394
Sequence ENST00000642557.4_utr3_3_0_chr1_3775922_f processed (70/1814), loss: 21.6, acc:0.394


  4%|▍         | 71/1814 [08:11<5:02:17, 10.41s/it]

Sequence ENST00000378251.3_utr3_6_0_chr1_3778559_r processed (71/1814), loss: 21.6, acc:0.394


  4%|▍         | 72/1814 [08:45<7:20:36, 15.18s/it]

Sequence ENST00000378230.8_utr3_21_0_chr1_3812086_r processed (72/1814), loss: 21.5, acc:0.399


KeyboardInterrupt: 

In [None]:
seq_names = list(all_probas.keys())
probs = [np.array(x) for x in all_probas.values()]
seqs = original_seqs.loc[seq_names].values.tolist()

if reverse_seq_neg_strand:
    probs = [x[::-1,[3,2,1,0]] if strand_info.loc[seq_name]=='-' else x for x, seq_name in zip(probs,seq_names)]
    seqs = [reverse_complement(x) if strand_info.loc[seq_name]=='-' else x for x, seq_name in zip(seqs,seq_names)]

In [None]:
#with open(data_dir + f'motif_predictions/split_75_25/ntrans/NT-MS-v2-500M_{fold}.pickle', 'wb') as f:
#    pickle.dump({'seq_names':seq_names,'seqs':seqs, 'probs':probs, 'fasta':fasta},f)

print('Done')