In [None]:
import tokenizers
import torch
from datasets import load_dataset

import deepchopper
from deepchopper.models import KmerPreTokenizer

In [None]:
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

In [None]:
from tokenizers import NormalizedString, PreTokenizedString, Regex, Tokenizer
from tokenizers.decoders import Decoder
from tokenizers.models import BPE
from tokenizers.normalizers import Normalizer
from tokenizers.pre_tokenizers import PreTokenizer

import deepchopper


def compute_tokens_to_ids(kmer_size: int) -> tuple[dict[str, int], list[str]]:
    kmers_tokens = deepchopper.generate_kmers(deepchopper.default.BASES, kmer_size)
    standard_tokens = kmers_tokens

    unknown_token = "<UNK>"
    padding_token = "<PAD>"
    eos_token = "<EOS>"
    bos_token = "<BOS>"
    sep_token = "<SEP>"

    specical_tokens = [padding_token, unknown_token, eos_token, bos_token, sep_token]
    all_tokens = standard_tokens + specical_tokens
    tokens_to_ids = {tok: i for i, tok in enumerate(all_tokens)}
    return tokens_to_ids, all_tokens


class KmerPreTokenizer:
    def __init__(self, kmer_size: int, *, overlap: bool):
        self.kmer_size = kmer_size
        self.overlap = overlap

    def kmer_split(self, i: int, normalized_string: NormalizedString) -> list[NormalizedString]:
        return [
            normalized_string[start:end]
            for (_token, (start, end)) in deepchopper.seq_to_kmers_and_offset(
                sequence, self.kmer_size, self.overlap
            )
        ]

    def pre_tokenize(self, pretok: PreTokenizedString):
        # Let's call split on the PreTokenizedString to split using `self.jieba_split`
        pretok.split(self.kmer_split)


class KmerDecoder:
    def decode(self, tokens: list[str]) -> str:
        return "".join(tokens)

In [None]:
from rich.console import Console
from rich.text import Text


def hight_text(text: str, start: int, end: int):
    text = Text(text)
    console = Console()
    text.stylize("bold magenta", start, end)
    console.print(text)

In [None]:
def test_pre_tokenize_str_no_overlap():
    tokenizer = KmerPreTokenizer(3, overlap=False)
    sequence = "ATCGGCC"
    expected_output = [("ATC", (0, 3)), ("GGC", (3, 6))]
    res = tokenizer.pre_tokenize_str(sequence)
    assert res == expected_output

In [None]:
data_files = {"train": "../tests/data/test_input.parquet"}
num_proc = 8
train_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[:70%]"
)
val_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[70%:90%]"
)
test_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[90%:]"
)

print(f"train_dataset: {train_dataset}")
print(f"val_dataset: {val_dataset}")
print(f"test_dataset: {test_dataset}")

In [None]:
train_dataset["seq"][0]
train_dataset["id"][0]
# train_dataset['qual'][0]
train_dataset["target"][0]

In [None]:
hight_text(train_dataset["seq"][0], *train_dataset["target"][0])

In [None]:
# deepchopper.seq_to_kmers(train_dataset['seq'][0], 5, overlap=False)

In [None]:
# test_dataset.map(lambda x : partial(deepchopper.seq_to_kmers, overlap=False, k=5)(x['seq']))
# test_dataset.map(lambda x : print(x['seq']))

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel

tokenizer = Tokenizer(WordLevel())
tokenizer.pre_tokenizer = PreTokenizer.custom(KmerPreTokenizer(3, overlap=True))

In [None]:
ts = train_dataset["seq"][0]

In [None]:
tokenizer.pre_tokenizer.pre_tokenize_str(ts)

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

In [None]:
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

In [None]:
from tokenizers.pre_tokenizers import ByteLevel

tokenizer.pre_tokenizer = ByteLevel()

In [None]:
tokenizer.train?

In [44]:
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

In [62]:
tokenizer.save_pretrained("./dnabert2_117M")

('./dnabert2_117M/tokenizer_config.json',
 './dnabert2_117M/special_tokens_map.json',
 './dnabert2_117M/tokenizer.json')

In [46]:
tt = tokenizer(ts)

In [None]:
output = tokenizer.encode(ts)

In [None]:
output.ids

In [None]:
tokenizer.convert_tokens_to_string?

In [None]:
tokenizer.decode(output)

In [47]:
tokenizer.save_pretrained

{'input_ids': [1, 4085, 513, 27, 229, 2886, 3551, 222, 671, 131, 728, 1403, 145, 73, 1154, 2482, 36, 197, 135, 421, 1310, 103, 183, 23, 430, 134, 76, 1973, 634, 24, 1033, 68, 73, 45, 949, 59, 3153, 2595, 2219, 301, 1697, 2246, 2470, 1008, 78, 30, 250, 73, 67, 115, 25, 268, 236, 513, 460, 2870, 135, 77, 85, 519, 107, 66, 50, 52, 574, 29, 1027, 18, 71, 480, 30, 535, 45, 157, 423, 2855, 719, 31, 1848, 57, 572, 3189, 539, 901, 29, 176, 41, 423, 532, 674, 36, 423, 423, 948, 674, 281, 83, 105, 185, 605, 68, 188, 1689, 86, 719, 671, 105, 27, 703, 200, 24, 361, 22, 161, 2436, 33, 102, 314, 128, 21, 138, 25, 365, 36, 249, 1317, 71, 131, 1543, 124, 34, 409, 76, 103, 772, 247, 82, 349, 918, 772, 515, 32, 330, 182, 24, 128, 49, 263, 595, 191, 212, 1182, 61, 53, 860, 53, 1028, 410, 295, 3818, 801, 176, 30, 253, 628, 62, 1702, 330, 65, 69, 610, 1300, 3883, 1800, 29, 191, 28, 296, 2389, 28, 170, 53, 2356, 31, 319, 42, 45, 2145, 183, 76, 49, 568, 185, 98, 206, 86, 140, 249, 146, 605, 1523, 20, 369, 23

In [48]:
tokenizer.save("data/tokenizer-wiki.json")

AttributeError: 'PreTrainedTokenizerFast' object has no attribute 'save'

In [51]:
token_test_dataset = test_dataset.map(lambda x : tokenizer(x['seq']))

In [50]:
test_dataset

Dataset({
    features: ['id', 'seq', 'qual', 'target'],
    num_rows: 500
})

In [57]:
tokenizer.convert_ids_to_tokens(token_test_dataset['input_ids'][0])

['[CLS]',
 'AA',
 'GTACTG',
 'CGG',
 'CCGTG',
 'TGGG',
 'TGAGTT',
 'GGCTG',
 'CC',
 'GGTGA',
 'GTTGGGG',
 'TGCC',
 'GGTG',
 'GAGTC',
 'GTGTT',
 'GGTCC',
 'TCAGAA',
 'TCCCC',
 'GCGTA',
 'GCC',
 'GCTG',
 'CCTCC',
 'TCC',
 'TACC',
 'CTC',
 'GCCATGTT',
 'CACCCC',
 'GG',
 'TCTGA',
 'GTA',
 'CGA',
 'CAGG',
 'GCAA',
 'TGTGAA',
 'TACTTTT',
 'TCACC',
 'CGAA',
 'GGAA',
 'GATTATT',
 'CAA',
 'GTGGAA',
 'TATG',
 'CCATT',
 'GAGG',
 'CTATCAA',
 'GCTT',
 'GGTT',
 'CTA',
 'CAGC',
 'CATT',
 'GGGA',
 'TCCAGA',
 'CATA',
 'TCGAGG',
 'GTGTGTG',
 'CCTA',
 'GCTGTG',
 'GAGAA',
 'GAGAA',
 'TTA',
 'CTTCC',
 'CCCA',
 'CTGA',
 'TGGA',
 'GCC',
 'CAGCA',
 'GCATT',
 'GA',
 'GAAAATT',
 'GTAGAGA',
 'GATT',
 'GATG',
 'CTCA',
 'CA',
 'TAGGTT',
 'GTGCCA',
 'TGAGTG',
 'GCTGCTG',
 'CTGATG',
 'CTAA',
 'GACTT',
 'TAATT',
 'GATAAA',
 'GCCAGA',
 'GTGGAGA',
 'CACAGAA',
 'CCTT',
 'GGTT',
 'CACCTA',
 'CAA',
 'TGAGA',
 'CAATGA',
 'CAGTG',
 'GAGA',
 'GTGTGA',
 'CCCAA',
 'GCTGTG',
 'TCCAA',
 'TCTG',
 'GCTT',
 'TGCA',
 'GTTTGGA',
 'GA

In [58]:
test_dataset['seq'][0]

'AAGTACTGCGGCCGTGTGGGTGAGTTGGCTGCCGGTGAGTTGGGGTGCCGGTGGAGTCGTGTTGGTCCTCAGAATCCCCGCGTAGCCGCTGCCTCCTCCTACCCTCGCCATGTTCACCCCGGTCTGAGTACGACAGGGCAATGTGAATACTTTTTCACCCGAAGGAAGATTATTCAAGTGGAATATGCCATTGAGGCTATCAAGCTTGGTTCTACAGCCATTGGGATCCAGACATATCGAGGGTGTGTGCCTAGCTGTGGAGAAGAGAATTACTTCCCCCACTGATGGAGCCCAGCAGCATTGAGAAAATTGTAGAGAGATTGATGCTCACATAGGTTGTGCCATGAGTGGCTGCTGCTGATGCTAAGACTTTAATTGATAAAGCCAGAGTGGAGACACAGAACCTTGGTTCACCTACAATGAGACAATGACAGTGGAGAGTGTGACCCAAGCTGTGTCCAATCTGGCTTTGCAGTTTGGAGAAGAAGATGCAGATCCAGGTGCCATGTCTTGCCCTTCTGGAGTAGCATTATATTTGGAGTTGATGAGAAAGGACCCCAGCTGTTATGGACCCATCTGGGACCTTTGTACAGTGTGATGCTCGAGCAATTGGCTCTGCTTCAGAGGGTGCCCAGAGCTCCTTGCAAGAAGTTTACCACAAGTCTATGACTTTGAAAGAAGCCATTAGTCATTCACTCATCATCCTCAAACAAGTAATGGAGGAGAAGCTGAATGCAACAAACATTGAGCTAGCCACAGTGCAGCCTGGCCAGAATTTCCACATATCAAGGAAGAACACTGAAGAGGTTATACAAGGGCATTACCTTAAGGAATCCTGATCTCAGAACTTCTTTCTCTGGGACAATCTCAGTTCTAATTATGTCCTTAAATTCATCTCCAGCTCCTGTTCCTTGGAAAATCTCCATTGTATGTGCATTCTTAATGATGTCTGTCAAAGGCAGTTCTGAAAATAAAGAAAATCTTTAAAATAAAAAAAAA

In [66]:
tt2 = tokenizers.Tokenizer.from_file("dnabert2_117M/tokenizer.json")

In [83]:
test_dataset

Dataset({
    features: ['id', 'seq', 'qual', 'target'],
    num_rows: 500
})

In [71]:
ts2 = tt2.encode(ts)

In [82]:
ts2.tokens

['[CLS]',
 'GCAGCTA',
 'TGAATG',
 'CAA',
 'GGCCA',
 'CAAGGTG',
 'GATGGAA',
 'GAGTT',
 'GTGGAA',
 'CCAAA',
 'GAGCTG',
 'TCTTCCA',
 'GAGAA',
 'GATT',
 'TCGAGA',
 'TAAGTC',
 'GCC',
 'CATCA',
 'GTGAA',
 'CAAGA',
 'TATTGTT',
 'GGTG',
 'GCATT',
 'TGA',
 'TGAGAA',
 'CGTT',
 'CCAA',
 'GATTATT',
 'GACAGA',
 'TTA',
 'GTGAAAA',
 'GTAA',
 'GATT',
 'GAAA',
 'TCATGA',
 'CTGA',
 'CCGTAA',
 'GTGGCAA',
 'GAAAGG',
 'GCTTTT',
 'GCCTTTG',
 'TAACCTT',
 'TGACGA',
 'CCATGA',
 'CTCC',
 'GTG',
 'GATAA',
 'GATT',
 'GTCA',
 'TTCA',
 'GAA',
 'TACCA',
 'TACTG',
 'TGAATG',
 'GCCACA',
 'TCTTTATT',
 'GTGAA',
 'GTTA',
 'GAAAA',
 'GCCCTG',
 'TCAAA',
 'GCAA',
 'GAGA',
 'TGAA',
 'TCAGTG',
 'CTT',
 'CTCCAGCC',
 'AAA',
 'GAGG',
 'TCGAA',
 'GTG',
 'GTTCTG',
 'GAAA',
 'CTTTG',
 'GTGGTG',
 'GTCGTG',
 'GAGGTG',
 'GTT',
 'TCGGTG',
 'GGAA',
 'TGACAA',
 'CTCGG',
 'TCGTG',
 'GAGGAAA',
 'CTT',
 'CAGTG',
 'GTC',
 'GTGGTG',
 'GCTTTG',
 'GTGGCA',
 'GCC',
 'GTGGTG',
 'GTGGTG',
 'GATATG',
 'GTGGCA',
 'GTGGG',
 'GATG',
 'GCTA',
 'TAATG',

In [75]:
ts2.offsets

[(0, 0),
 (0, 7),
 (7, 13),
 (13, 16),
 (16, 21),
 (21, 28),
 (28, 35),
 (35, 40),
 (40, 46),
 (46, 51),
 (51, 57),
 (57, 64),
 (64, 69),
 (69, 73),
 (73, 79),
 (79, 85),
 (85, 88),
 (88, 93),
 (93, 98),
 (98, 103),
 (103, 110),
 (110, 114),
 (114, 119),
 (119, 122),
 (122, 128),
 (128, 132),
 (132, 136),
 (136, 143),
 (143, 149),
 (149, 152),
 (152, 159),
 (159, 163),
 (163, 167),
 (167, 171),
 (171, 177),
 (177, 181),
 (181, 187),
 (187, 194),
 (194, 200),
 (200, 206),
 (206, 213),
 (213, 220),
 (220, 226),
 (226, 232),
 (232, 236),
 (236, 239),
 (239, 244),
 (244, 248),
 (248, 252),
 (252, 256),
 (256, 259),
 (259, 264),
 (264, 269),
 (269, 275),
 (275, 281),
 (281, 289),
 (289, 294),
 (294, 298),
 (298, 303),
 (303, 309),
 (309, 314),
 (314, 318),
 (318, 322),
 (322, 326),
 (326, 332),
 (332, 335),
 (335, 343),
 (343, 346),
 (346, 350),
 (350, 355),
 (355, 358),
 (358, 364),
 (364, 368),
 (368, 373),
 (373, 379),
 (379, 385),
 (385, 391),
 (391, 394),
 (394, 400),
 (400, 404),
 (40