In [48]:
%pprint

Pretty printing has been turned OFF


In [1]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [60]:
import numpy as np
import pandas as pd
import random
import math
import torch as torch


In [3]:
sequence_file = 'train_sequences.parquet'

In [4]:
sequences_df = pd.read_parquet(sequence_file)

In [5]:
sequences_df.shape

(806573, 2)

In [6]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [7]:
# sequences_df.sample(10)['sequence'].tolist()

### BPE tokenizer

In [181]:
tokenizer = Tokenizer(models.BPE())

In [182]:
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

In [183]:
def get_training_corpus():
    for i in range(100):
        yield sequences_df.sample(1000)['sequence'].tolist()

In [184]:
# next(iter(get_training_corpus))

In [185]:
special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]"]

In [186]:
# trainer = trainers.BpeTrainer(vocab_size=1000, special_tokens=["<|endoftext|>"])

trainer = trainers.BpeTrainer(vocab_size=256, max_token_length=10, special_tokens=special_tokens, show_preogress=True)

# trainer = trainers.BpeTrainer(max_token_length=10, special_tokens=special_tokens, show_preogress=True)

In [187]:
tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)

In [188]:
tokenizer.save('tokenizer.json')

In [189]:
tokenizer.encode("[PAD][CLS][SEP][MASK]").ids

[0, 1, 2, 3]

In [190]:
tokenizer.token_to_id("[MASK]")

3

In [191]:
seq = 'GGGAACGACUCGAGUAGAGUCGAAAAUUGUGUUAACAUCGCACUCGGUAGCUAAUUUAAGUGCUCCUACGCUUGUCCCGCAGGAGAAUUAUAGUAGCAUUAGAUUUGCUAGUGUUUAUAGUGUGCUGAUAGCGAGUGACUUCGGUCACUCGCUAUCAAAAGAAACAACAACAACAAC'

In [192]:
encoding = tokenizer.encode(seq)
print(encoding.tokens)

['GGGAACGACUCGAGUAGAGUCGAAAA', 'UUG', 'UGUU', 'AA', 'CAUCG', 'CAC', 'UCGG', 'UAG', 'CUAA', 'UU', 'UAAG', 'UG', 'CUC', 'CUA', 'CG', 'CUUG', 'UC', 'CCG', 'CAGG', 'AGAA', 'UU', 'AUAG', 'UAG', 'CAUU', 'AG', 'AUU', 'UG', 'CUAG', 'UGUU', 'UAUAG', 'UGUG', 'CUG', 'AUAG', 'CGAG', 'UGAC', 'UUCGG', 'UCAC', 'UCG', 'CUA', 'UCAAAAGAAACAACAACAACAAC']


In [193]:
len(encoding.tokens)

40

In [194]:
tokenizer.get_vocab_size()

256

In [180]:
len('GGGAACGACUCGAGUAGAGUCGAAAAAA')

28

In [195]:
tokenizer.get_vocab()

{'CAAAAGAAACAACAACAACAAC': 58, 'CUUG': 109, 'AUG': 30, 'UGGUG': 202, 'UAUAC': 252, 'CGCG': 113, 'UAAAA': 90, 'UGAAG': 223, 'UUCG': 25, 'AGG': 50, 'UAUC': 148, 'AUC': 48, 'UUAC': 104, 'CCAAAAGAAACAACAACAACAAC': 175, 'UAGGG': 249, '[PAD]': 0, 'CACG': 100, '[MASK]': 3, 'CUUUU': 248, 'UGUUG': 170, 'AUGG': 72, 'CUAC': 98, 'UAUAA': 182, 'CUUCG': 155, 'CAACAA': 18, 'CUUGG': 185, 'CGUC': 232, 'GAAACAACAACAACAAC': 44, 'UUUU': 67, 'CCUUCGGG': 213, 'CAUA': 172, 'CCUG': 108, 'CCAAG': 243, 'AG': 10, 'GAAACG': 201, 'UUAG': 136, 'UCCUAAGUCAA': 199, 'GGGAACGACUCGAGUAGAGUCGAAAAAAAA': 169, 'ACGAA': 218, 'GAAG': 115, 'ACGG': 120, 'UAUU': 118, 'CGGG': 103, 'CUAA': 101, 'UUUG': 78, 'CAUC': 119, 'AGAAACAACAACAACAAC': 131, 'ACAAAAGAAACAACAACAACAAC': 196, 'CAACAACAACAAC': 39, 'CUUAA': 227, 'UCAC': 102, 'AGGAG': 206, 'UCC': 69, 'CAGCC': 229, 'AAAAGAAACAACAACAACAAC': 62, 'UGUUCG': 189, 'CUU': 47, 'UGAC': 126, 'CUCG': 77, 'CGG': 55, 'UUUUG': 220, 'UAAAAGAAACAACAACAACAAC': 132, 'UCUUCGG': 173, 'AGAA': 64, 'CAGGG'

In [31]:
# encoding.attention_mask

In [49]:
encoding.ids

[621, 318, 219, 52, 123, 21, 100, 10, 56, 366, 280, 108, 399, 92, 64, 10, 104, 21, 94, 11, 505, 111, 70, 241, 73, 46, 104, 98, 127, 59, 102, 17, 67, 142]

In [50]:
from transformers import DataCollatorForLanguageModeling

In [85]:
tokenizer.enable_padding()
tokenizer.mask_token = 3

In [57]:
# tokenizer.pad_token

In [59]:
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [66]:
mask = torch.zeros(100, dtype=torch.bool)
mask[:len(encoding.ids)] = True

In [69]:
encoding.ids

[621, 318, 219, 52, 123, 21, 100, 10, 56, 366, 280, 108, 399, 92, 64, 10, 104, 21, 94, 11, 505, 111, 70, 241, 73, 46, 104, 98, 127, 59, 102, 17, 67, 142]

In [70]:
encoding.pad(length=100)
encoding.ids

[621, 318, 219, 52, 123, 21, 100, 10, 56, 366, 280, 108, 399, 92, 64, 10, 104, 21, 94, 11, 505, 111, 70, 241, 73, 46, 104, 98, 127, 59, 102, 17, 67, 142, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [71]:
encoding.tokens

['GGGAACGACUCGAGUAGAGUCGAAAAUUG', 'UGUUAA', 'CAUCG', 'CAC', 'UCGG', 'UAG', 'CUAA', 'UU', 'UAAG', 'UGCUC', 'CUACG', 'CUUG', 'UCCCG', 'CAGG', 'AGAA', 'UU', 'AUAG', 'UAG', 'CAUU', 'AG', 'AUUUG', 'CUAG', 'UGUU', 'UAUAG', 'UGUG', 'CUG', 'AUAG', 'CGAG', 'UGAC', 'UUCGG', 'UCAC', 'UCG', 'CUA', 'UCAAAAGAAACAACAACAACAAC', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

In [67]:
mask

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [74]:
rand = torch.rand(mask.sum())
mask_arr = rand < 0.15

In [75]:
rand, mask_arr

(tensor([0.1014, 0.1817, 0.6547, 0.0950, 0.5827, 0.3525, 0.3917, 0.4392, 0.8931,
        0.0124, 0.0474, 0.5720, 0.7494, 0.6528, 0.8663, 0.2906, 0.6303, 0.6690,
        0.9298, 0.0560, 0.8750, 0.4932, 0.6929, 0.9576, 0.1168, 0.3224, 0.6660,
        0.1057, 0.7689, 0.5415, 0.1716, 0.3628, 0.0567, 0.9902]), tensor([ True, False, False,  True, False, False, False, False, False,  True,
         True, False, False, False, False, False, False, False, False,  True,
        False, False, False, False,  True, False, False,  True, False, False,
        False, False,  True, False]))

In [79]:
selection = torch.flatten((mask_arr).nonzero()).tolist()
selection

[0, 3, 9, 10, 19, 24, 27, 32]

In [107]:
tok_seq = torch.IntTensor(encoding.ids)
tok_seq

tensor([621, 318, 219,  52, 123,  21, 100,  10,  56, 366, 280, 108, 399,  92,
         64,  10, 104,  21,  94,  11, 505, 111,  70, 241,  73,  46, 104,  98,
        127,  59, 102,  17,  67, 142,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], dtype=torch.int32)

In [108]:
mlm = tok_seq.detach().clone()
mlm[selection] = tokenizer.mask_token

In [109]:
mlm

tensor([  3, 318, 219,   3, 123,  21, 100,  10,  56,   3,   3, 108, 399,  92,
         64,  10, 104,  21,  94,   3, 505, 111,  70, 241,   3,  46, 104,   3,
        127,  59, 102,  17,   3, 142,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], dtype=torch.int32)

In [110]:
mlm_target = tok_seq.masked_fill(~(mlm == tokenizer.mask_token), 0)

In [111]:
mlm_target

tensor([621,   0,   0,  52,   0,   0,   0,   0,   0, 366, 280,   0,   0,   0,
          0,   0,   0,   0,   0,  11,   0,   0,   0,   0,  73,   0,   0,  98,
          0,   0,   0,   0,  67,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], dtype=torch.int32)