# Tokenize DNA using Byte-Pair Encoding

In [1]:
import sys

# Put the parent directory on the import path so the local `adna` package
# (imported below) resolves; presumably the notebook is run from a
# subdirectory of the repository root — TODO confirm.
sys.path.append('..')

In [2]:
from collections import defaultdict

from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from tqdm import tqdm

from adna.pylib import consts, datasets

In [6]:
# Load the dataset. Only the first returned value (the sequences) is kept; the
# second is discarded. `SEQS` is iterated several times in this notebook, so
# it is assumed to be a re-iterable, sized collection of DNA strings — TODO confirm.
SEQS, _ = datasets.read_dataset()

## What characters are used?

In [4]:
def count_bases(seqs):
    """Tally how often each character appears across all sequences.

    Args:
        seqs: Iterable of sequences; each sequence is iterated character by
            character. The outer iteration is wrapped in ``tqdm`` so a
            progress bar is shown.

    Returns:
        A ``defaultdict(int)`` mapping each character to its total count,
        with keys in first-seen order.
    """
    totals = defaultdict(int)
    # Flatten sequences into one stream of characters; tqdm still tracks
    # progress over the outer collection.
    every_symbol = (symbol for sequence in tqdm(seqs) for symbol in sequence)
    for symbol in every_symbol:
        totals[symbol] += 1
    return totals


# Display the per-character totals for the whole dataset.
count_bases(SEQS)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 478722/478722 [00:04<00:00, 97595.34it/s]


defaultdict(int,
            {'T': 17125866,
             'C': 16246533,
             'A': 14510124,
             'G': 8819707,
             'N': 1406})

## Train the tokenizer

In [5]:
# Byte-level BPE tokenizer created without vocabulary/merges files; its
# vocabulary is learned from SEQS by the training call that follows.
tokenizer = ByteLevelBPETokenizer()

In [7]:
# Learn BPE merges directly from the in-memory sequences. The special tokens
# are reserved in the vocabulary (presumably including BOS, EOS and PAD, whose
# ids are looked up later) — TODO confirm against consts.SPECIAL_TOKENS.
tokenizer.train_from_iterator(
    SEQS,
    special_tokens=consts.SPECIAL_TOKENS,
    min_frequency=consts.MIN_FREQ,
    vocab_size=consts.VOCAB_SIZE,
)






In [8]:
# Wrap every encoding as `BOS ... EOS`. BertProcessing's arguments are
# (sep, cls): the `sep` pair is appended after the sequence and the `cls`
# pair is prepended. That is why EOS is passed first and BOS second.
tokenizer.post_processor = BertProcessing(
    (consts.EOS, tokenizer.token_to_id(consts.EOS)),
    (consts.BOS, tokenizer.token_to_id(consts.BOS)),
)

## Get tokenized lengths

In [9]:
# Histogram: encoded length (in tokens, including special tokens) -> number
# of sequences with that length.
lengths = defaultdict(int)

# Number of sequences passed to `encode_batch` per call.
STEP = 1024

# Materialize the loaded sequences as a list so they support len() and
# slicing in the batching loop. (This previously read from `aug_seqs`, a name
# never defined in this notebook, which raises NameError on a fresh kernel.)
SEQS = list(SEQS)

In [10]:
# Encode the dataset in STEP-sized slices and tally each encoding's length.
for start in tqdm(range(0, len(SEQS), STEP)):
    chunk = SEQS[start:start + STEP]
    for encoding in tokenizer.encode_batch(chunk):
        lengths[len(encoding)] += 1

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 468/468 [00:05<00:00, 85.12it/s]


In [11]:
# Show the length histogram as (length, count) pairs in ascending length order.
[(length, lengths[length]) for length in sorted(lengths)]

[(3, 4),
 (4, 106),
 (5, 997),
 (6, 3777),
 (7, 8654),
 (8, 15568),
 (9, 23642),
 (10, 31235),
 (11, 35884),
 (12, 37164),
 (13, 36703),
 (14, 34006),
 (15, 32256),
 (16, 29134),
 (17, 25957),
 (18, 23127),
 (19, 20332),
 (20, 17025),
 (21, 14585),
 (22, 12201),
 (23, 10293),
 (24, 8489),
 (25, 7324),
 (26, 6153),
 (27, 4882),
 (28, 4035),
 (29, 3546),
 (30, 3022),
 (31, 2433),
 (32, 2112),
 (33, 1951),
 (34, 2212),
 (35, 1603),
 (36, 1203),
 (37, 915),
 (38, 821),
 (39, 1075),
 (40, 848),
 (41, 725),
 (42, 724),
 (43, 757),
 (44, 795),
 (45, 650),
 (46, 547),
 (47, 473),
 (48, 422),
 (49, 505),
 (50, 587),
 (51, 581),
 (52, 436),
 (53, 374),
 (54, 537),
 (55, 593),
 (56, 687),
 (57, 620),
 (58, 365),
 (59, 366),
 (60, 422),
 (61, 448),
 (62, 412),
 (63, 382),
 (64, 227),
 (65, 252),
 (66, 271),
 (67, 111),
 (68, 59),
 (69, 48),
 (70, 53),
 (71, 13),
 (72, 2),
 (73, 1),
 (74, 2),
 (75, 1)]

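Given the distribution above (the longest encoded sequence is 75 tokens, including the `<s>` and `</s>` special tokens), I'm going to use a maximum sequence length of 80 tokens (`consts.MAX_LENGTH`) below.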

In [12]:
# Fixed sequence length (in tokens) used when configuring padding.
consts.MAX_LENGTH

80

## Finalize the tokenizer

In [13]:
# Pad every encoding up to exactly `consts.MAX_LENGTH` tokens with PAD.
tokenizer.enable_padding(
    pad_token=consts.PAD,
    pad_id=tokenizer.token_to_id(consts.PAD),
    length=consts.MAX_LENGTH,
)
# Padding only lengthens short encodings. Without truncation, an input that
# encodes to more than MAX_LENGTH tokens would come out longer than the fixed
# length, both here and in the saved tokenizer. Truncation's max_length counts
# the special tokens added by the post-processor, so BOS/EOS are kept.
tokenizer.enable_truncation(max_length=consts.MAX_LENGTH)

In [14]:
# Spot-check the finished tokenizer: encode the first sequence and show its
# tokens, including the BOS/EOS special tokens and trailing padding.
encoded = tokenizer.encode(SEQS[0])
encoded.tokens

['<s>',
 'TCAACCAATTGTG',
 'TAC',
 'TCGCC',
 'GCACTGGAGGTGTAG',
 'AGTG',
 'ATATT',
 'GCCC',
 'AAAA',
 'ATAG',
 'AG',
 'AACCAACC',
 'GAACTACTCCATTAAAATG',
 'TCGCGATTACGAGGCAG',
 'TGAG',
 'TCC',
 'TTCCTCC',
 '</s>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>']

In [15]:
# Serialize the whole tokenizer (model, post-processor and padding settings)
# to a single JSON file. The target directory presumably must already exist —
# TODO confirm that save() does not create it.
tokenizer.save(str(consts.MT_DIR / 'tokenizer.json'))