# Tokenize DNA using Byte-Pair Encoding

In [1]:
import sys

sys.path.append('..')

In [2]:
from collections import defaultdict

from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from tqdm import tqdm

from adna.pylib import consts, datasets

## What characters are used?

In [3]:
# Dataset constructed with no arguments (i.e. the class defaults); it is only
# used below to tally which characters appear in the sequences.
counts_dataset = datasets.ADnaDataset()

In [4]:
def count_bases(dataset):
    """Tally every character seen across all sequences in `dataset`.

    Iterates `dataset` once, wrapped in a tqdm progress bar, and returns a
    defaultdict(int) mapping each character to its number of occurrences.
    """
    totals = defaultdict(int)
    progress = tqdm(dataset)
    for sequence in progress:
        for symbol in sequence:
            totals[symbol] = totals[symbol] + 1
    return totals


# Display the per-character totals for the whole dataset.
count_bases(counts_dataset)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 478722/478722 [00:04<00:00, 104429.18it/s]


defaultdict(int,
            {'T': 17125866,
             'C': 16246533,
             'A': 14510124,
             'G': 8819707,
             'N': 1406})

## Train the tokenizer

In [5]:
# Start from an untrained byte-level BPE tokenizer; it is trained from the
# dataset sequences in the cells below.
tokenizer = ByteLevelBPETokenizer()

In [6]:
# Dataset used both to train the tokenizer and to survey tokenized lengths;
# both rates are taken from consts.
# NOTE(review): presumably rev_comp_rate / to_n_rate enable random
# reverse-complementing and base -> N substitution -- confirm in
# adna.pylib.datasets.
aug_seqs = datasets.ADnaDataset(
    rev_comp_rate=consts.REV_COMP_RATE, to_n_rate=consts.TO_N_RATE
)

In [7]:
# Learn the BPE vocabulary by iterating over the dataset directly. The
# vocabulary size, minimum frequency, and special tokens all come from consts.
tokenizer.train_from_iterator(
    aug_seqs,
    vocab_size=consts.VOCAB_SIZE,
    min_frequency=consts.MIN_FREQ,
    special_tokens=consts.SPECIAL_TOKENS,
)






In [8]:
# BertProcessing takes the separator (closing) pair first and the classifier
# (opening) pair second, each as (token, id) -- hence EOS before BOS.
closing_pair = (consts.EOS, tokenizer.token_to_id(consts.EOS))
opening_pair = (consts.BOS, tokenizer.token_to_id(consts.BOS))
tokenizer.post_processor = BertProcessing(closing_pair, opening_pair)

## Get tokenized lengths

In [9]:
# Number of sequences passed to the tokenizer per encode_batch call.
STEP = 1024

# Materialize one full pass over the dataset so it can be sliced into batches.
SEQS = list(aug_seqs)

# Histogram: tokenized length -> number of sequences with that length.
lengths = defaultdict(int)

In [10]:
# Encode the sequences STEP at a time and count how many encodings have each
# length.
for start in tqdm(range(0, len(SEQS), STEP)):
    encodings = tokenizer.encode_batch(SEQS[start:start + STEP])
    for encoding in encodings:
        lengths[len(encoding)] += 1

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:06<00:00, 75.71it/s]


In [11]:
# Show the length histogram ordered by token count.
sorted(lengths.items())

[(3, 1),
 (4, 21),
 (5, 214),
 (6, 866),
 (7, 1954),
 (8, 3795),
 (9, 6186),
 (10, 9084),
 (11, 12281),
 (12, 15616),
 (13, 19500),
 (14, 22197),
 (15, 24753),
 (16, 26516),
 (17, 28207),
 (18, 28457),
 (19, 27733),
 (20, 27018),
 (21, 25533),
 (22, 23775),
 (23, 21760),
 (24, 20248),
 (25, 17895),
 (26, 15874),
 (27, 13483),
 (28, 11610),
 (29, 9802),
 (30, 8447),
 (31, 7088),
 (32, 6037),
 (33, 4980),
 (34, 4091),
 (35, 3497),
 (36, 2885),
 (37, 2434),
 (38, 2215),
 (39, 1903),
 (40, 1627),
 (41, 1393),
 (42, 1275),
 (43, 1113),
 (44, 1040),
 (45, 989),
 (46, 887),
 (47, 888),
 (48, 819),
 (49, 808),
 (50, 801),
 (51, 773),
 (52, 728),
 (53, 659),
 (54, 591),
 (55, 573),
 (56, 638),
 (57, 520),
 (58, 512),
 (59, 524),
 (60, 527),
 (61, 457),
 (62, 484),
 (63, 447),
 (64, 404),
 (65, 344),
 (66, 252),
 (67, 250),
 (68, 177),
 (69, 119),
 (70, 70),
 (71, 35),
 (72, 22),
 (73, 12),
 (74, 5),
 (75, 3)]

Given the distribution above, where no sequence tokenizes to more than 75 tokens, I'm going to use the maximum sequence length stored in `consts.MAX_LENGTH` (shown below).

In [12]:
# Display the configured maximum sequence length (in tokens).
consts.MAX_LENGTH

100

## Finalize the tokenizer

In [13]:
# Pad every encoding out to consts.MAX_LENGTH using the PAD token and its id.
# NOTE(review): only padding is enabled here, not truncation, so an encoding
# longer than MAX_LENGTH would presumably come out longer than the padded
# ones -- confirm whether truncation is applied elsewhere.
tokenizer.enable_padding(
    pad_token=consts.PAD,
    pad_id=tokenizer.token_to_id(consts.PAD),
    length=consts.MAX_LENGTH,
)

In [14]:
# Spot-check: encode the first sequence and show its tokens, including the
# special tokens and the padding.
encoded = tokenizer.encode(SEQS[0])
encoded.tokens

['<s>',
 'GGAGG',
 'AAGG',
 'ACTC',
 'ACTGCC',
 'NCG',
 'TAATCGCGACATTTTAATGGAGTAGTTCGGTTGG',
 'TTCTC',
 'TATT',
 'TTTGGGCAA',
 'TATC',
 'ACTCTACACCTCC',
 'AGTGCGGC',
 'GAG',
 'TACACAATTGGTTG',
 'A',
 '</s>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 

In [15]:
# Write the tokenizer as configured above to tokenizer.json under
# consts.SUB_DIR.
tokenizer.save(str(consts.SUB_DIR / 'tokenizer.json'))