In [None]:
import tokenizers
import torch
from datasets import load_dataset

import deepchopper
from deepchopper.models import KmerPreTokenizer

In [None]:
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [None]:
from tokenizers import NormalizedString, PreTokenizedString, Regex, Tokenizer
from tokenizers.decoders import Decoder
from tokenizers.models import BPE
from tokenizers.normalizers import Normalizer
from tokenizers.pre_tokenizers import PreTokenizer

import deepchopper


def compute_tokens_to_ids(kmer_size: int) -> tuple[dict[str, int], list[str]]:
    """Build the k-mer vocabulary and its token-to-id mapping.

    The vocabulary is the list returned by ``deepchopper.generate_kmers`` for
    ``deepchopper.default.BASES`` followed by the five special tokens, so the
    special tokens sit at the end of the vocabulary.

    Args:
        kmer_size: Forwarded to ``deepchopper.generate_kmers`` as the k-mer size.

    Returns:
        A ``(tokens_to_ids, all_tokens)`` pair in which each token's id is its
        position in ``all_tokens``.
    """
    special_tokens = ["<PAD>", "<UNK>", "<EOS>", "<BOS>", "<SEP>"]
    vocabulary = deepchopper.generate_kmers(deepchopper.default.BASES, kmer_size) + special_tokens

    # Ids are simply the positions in the vocabulary.
    token_ids = {}
    for index, token in enumerate(vocabulary):
        token_ids[token] = index
    return token_ids, vocabulary


class KmerPreTokenizer:
    """Custom pre-tokenizer component that cuts a sequence into k-mers.

    Intended to be wrapped with ``PreTokenizer.custom(...)`` before being
    attached to a ``Tokenizer``.

    Args:
        kmer_size: Length of each k-mer, forwarded to
            ``deepchopper.seq_to_kmers_and_offset``.
        overlap: Keyword-only flag forwarded to
            ``deepchopper.seq_to_kmers_and_offset``.
    """

    def __init__(self, kmer_size: int, *, overlap: bool):
        self.kmer_size = kmer_size
        self.overlap = overlap

    def kmer_split(self, i: int, normalized_string: "NormalizedString") -> list["NormalizedString"]:
        """Split one piece of the pre-tokenized string into k-mer slices.

        Args:
            i: Index of the piece; required by the split-callback signature, unused here.
            normalized_string: The piece to split.

        Returns:
            One slice of ``normalized_string`` for every ``(start, end)`` offset
            reported by ``deepchopper.seq_to_kmers_and_offset``.
        """
        # The k-mer helper is given the plain text; the returned pieces are slices
        # of the NormalizedString itself, not the helper's token strings.
        sequence = str(normalized_string)
        return [
            normalized_string[start:end]
            for (_token, (start, end)) in deepchopper.seq_to_kmers_and_offset(
                sequence, self.kmer_size, self.overlap
            )
        ]

    def pre_tokenize(self, pretok: "PreTokenizedString"):
        """Split ``pretok`` in place using ``self.kmer_split``."""
        pretok.split(self.kmer_split)


class KmerDecoder:
    """Decoder that turns a list of k-mer tokens back into one string."""

    def decode(self, tokens: list[str]) -> str:
        """Concatenate ``tokens`` in order, with no separator between them."""
        separator = ""
        return separator.join(tokens)

In [None]:
from rich.console import Console
from rich.text import Text


def hight_text(text: str, start: int, end: int):
    """Print ``text`` with the span between ``start`` and ``end`` in bold magenta.

    Args:
        text: The string to print.
        start: Start offset of the highlighted span.
        end: End offset of the highlighted span.
    """
    highlighted = Text(text)
    highlighted.stylize("bold magenta", start, end)
    Console().print(highlighted)

In [None]:
def test_pre_tokenize_str_no_overlap():
    """Check non-overlapping 3-mer pre-tokenization of a short sequence.

    The trailing base that does not fill a whole k-mer is expected to be dropped.
    """
    # KmerPreTokenizer only implements the custom-component hook (`pre_tokenize`);
    # `pre_tokenize_str` is provided by the `PreTokenizer.custom` wrapper.
    pre_tokenizer = PreTokenizer.custom(KmerPreTokenizer(3, overlap=False))
    sequence = "ATCGGCC"
    expected_output = [("ATC", (0, 3)), ("GGC", (3, 6))]
    res = pre_tokenizer.pre_tokenize_str(sequence)
    assert res == expected_output, f"expected {expected_output}, got {res}"

In [None]:
# A single parquet file is sliced 70% / 20% / 10% into train / validation / test.
data_files = {"train": "../tests/data/test_input.parquet"}
num_proc = 8

train_dataset, val_dataset, test_dataset = (
    load_dataset("parquet", data_files=data_files, num_proc=num_proc, split=split_spec)
    for split_spec in ("train[:70%]", "train[70%:90%]", "train[90%:]")
)

# Summaries of the three splits.
print(f"train_dataset: {train_dataset}")
print(f"val_dataset: {val_dataset}")
print(f"test_dataset: {test_dataset}")

In [None]:
# Index the row first and the column second: this reads a single row, whereas
# `train_dataset["seq"][0]` goes through the whole column to return one value.
train_dataset[0]["seq"]
train_dataset[0]["id"]
# train_dataset['qual'][0]
train_dataset[0]["target"]

In [None]:
# Highlight the target span of the first training sequence.
# Row-first indexing reads one row instead of going through whole columns.
hight_text(train_dataset[0]["seq"], *train_dataset[0]["target"])

In [None]:
# deepchopper.seq_to_kmers(train_dataset['seq'][0], 5, overlap=False)

In [None]:
# test_dataset.map(lambda x : partial(deepchopper.seq_to_kmers, overlap=False, k=5)(x['seq']))
# test_dataset.map(lambda x : print(x['seq']))

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel

# Word-level tokenizer (created without a vocabulary) whose pre-tokenization
# step is the custom k-mer splitter with k=3 and overlap enabled.
tokenizer = Tokenizer(WordLevel())
tokenizer.pre_tokenizer = PreTokenizer.custom(KmerPreTokenizer(3, overlap=True))

In [None]:
# First training sequence, reused by the tokenizer experiments below.
# Row-first indexing reads one row instead of going through the whole "seq" column.
ts = train_dataset[0]["seq"]

In [None]:
# Show what the tokenizer's pre-tokenizer produces for `ts`.
tokenizer.pre_tokenizer.pre_tokenize_str(ts)

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Rebinds `tokenizer` to a new BPE tokenizer that uses "[UNK]" as its unknown token.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

In [None]:
from tokenizers.trainers import BpeTrainer

# BPE trainer configured with the special tokens to include in the vocabulary.
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

In [None]:
from tokenizers.pre_tokenizers import ByteLevel

# Use byte-level pre-tokenization for the current tokenizer.
tokenizer.pre_tokenizer = ByteLevel()

In [None]:
# IPython help query (trailing `?`): shows the documentation of `tokenizer.train`;
# it does not start any training.
tokenizer.train?

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

# Rebinds `tokenizer` to the pretrained DNABERT-2 tokenizer loaded by name.
# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run locally; consider pinning a `revision` once one is vetted.
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

In [None]:
# Encode the sample sequence `ts` with the current tokenizer.
output = tokenizer.encode(ts)

In [None]:
# `encode` on the transformers tokenizer loaded above already returns the list of
# token ids; only a `tokenizers.Encoding` has an `.ids` attribute.
output

In [None]:
# IPython help query (trailing `?`): shows the documentation of
# `tokenizer.convert_tokens_to_string`.
tokenizer.convert_tokens_to_string?

In [None]:
# Decode the encoded output back into text.
tokenizer.decode(output)