In [None]:
from collections import Counter
from math import floor
from pathlib import Path
from random import Random
from typing import Any, Iterator, Optional

from datasets import Dataset, DatasetDict
from tokenizers import Tokenizer, models, trainers

In [None]:
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding

## Dummy Data

In [None]:
# --- Synthetic data ---------------------------------------------------------
NUM_TIMELINES = 10000
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Integer seed for the synthetic data. Despite its name, RANDOM_SEED is the seeded
# `Random` instance itself (name kept because later cells use it); rebuild it from
# SEED to regenerate identical data without restarting the kernel.
SEED = 23
RANDOM_SEED = Random(SEED)
# Upper bound (inclusive) on the number of consecutive letters per timestep; must
# agree with the bound `get_samples` uses when drawing `num_samples`.
MAX_NUM_SAMPLES = 3

# --- Special tokens ---------------------------------------------------------
SEPARATOR_TOKEN = "<SEP>"
PADDING_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"

# --- Output locations (relative to the current working directory) ----------
OUTPUT_DIR = Path.cwd() / "outputs"
SAVE_TOKENIZER_PATH = OUTPUT_DIR / "tokenizer"
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR / "encoded_dataset"

In [None]:
def get_samples(
    num_timelines: int,
    rng: Optional[Random] = None,
    letters: Optional[str] = None,
    max_num_samples: Optional[int] = None,
    max_num_skips: int = 2,
) -> Iterator[dict[str, Any]]:
    """Yield ``num_timelines`` synthetic timelines of letter chunks.

    Each timeline starts at a random index in the first half of ``letters``. It then
    repeatedly takes ``num_samples`` consecutive letters as one timestep and jumps
    ahead ``num_skips`` letters, until the alphabet is exhausted. The last timestep
    may be shorter than ``num_samples``.

    Args:
        num_timelines: number of samples to yield.
        rng: source of randomness; defaults to the notebook-level ``RANDOM_SEED``.
        letters: alphabet to walk; defaults to ``LETTERS``.
        max_num_samples: inclusive upper bound for ``num_samples`` (lower bound 1);
            defaults to ``MAX_NUM_SAMPLES``.
        max_num_skips: inclusive upper bound for ``num_skips`` (lower bound 0).

    Yields:
        A dict per timeline with:

        - ``timeline``: list of letter lists.
        - ``char_diffs``: ``num_skips`` repeated once per timestep.
        - ``start_idx``.
        - ``num_blanks``: equal to ``num_skips``.
        - ``num_samples``.
    """
    # Defaults are resolved here rather than in the signature, so the notebook
    # globals are only needed when the caller does not supply its own values.
    rng = RANDOM_SEED if rng is None else rng
    letters = LETTERS if letters is None else letters
    max_num_samples = MAX_NUM_SAMPLES if max_num_samples is None else max_num_samples
    num_letters = len(letters)

    for _ in range(num_timelines):
        # Keep this draw order (start, skips, samples): changing it changes the data
        # produced for a given seed.
        start_idx = rng.randint(0, num_letters // 2)
        num_skips = rng.randint(0, max_num_skips)
        num_samples = rng.randint(1, max_num_samples)

        timeline: list[list[str]] = []
        char_diffs: list[int] = []
        start_char_idx = start_idx
        while start_char_idx < num_letters:
            end_char_idx = min(start_char_idx + num_samples, num_letters)
            timeline.append(list(letters[start_char_idx:end_char_idx]))
            char_diffs.append(num_skips)
            # num_samples >= 1, so the cursor always advances and the loop ends.
            start_char_idx = end_char_idx + num_skips

        yield {
            "timeline": timeline,
            "char_diffs": char_diffs,
            "start_idx": start_idx,
            # The skip count is exposed under "num_blanks"; later cells read that key.
            "num_blanks": num_skips,
            "num_samples": num_samples,
        }

In [None]:
# `Dataset.from_generator` returns a single `Dataset`. Later cells index
# `dataset["train"]`, so wrap it in a `DatasetDict` with an explicit "train" split;
# otherwise "train" is looked up as a (non-existent) column and those cells fail.
dataset = DatasetDict(
    {"train": Dataset.from_generator(lambda: get_samples(NUM_TIMELINES))}
)
dataset["train"][0]

In [None]:
def batched_timeline_to_tokens(
    batched_samples: dict[str, list], separator: str
) -> dict[str, list]:
    """Flatten every timeline in a batch into a single token list.

    Each timestep contributes ``char_diff_<diff>``, then its letters, then
    ``separator``. The result is stored under the new ``"tokens"`` key.

    Args:
        batched_samples: batch holding parallel ``"timeline"`` and ``"char_diffs"``
            columns (one entry per sample).
        separator: token appended after each timestep.

    Returns:
        The same ``batched_samples`` dict, mutated to include ``"tokens"``.
    """
    tokens_per_sample: list[list[str]] = []
    samples = zip(batched_samples["timeline"], batched_samples["char_diffs"])
    for timeline, char_diffs in samples:
        flat_tokens: list[str] = []
        for timestep, char_diff in zip(timeline, char_diffs):
            flat_tokens.extend([f"char_diff_{char_diff}"] + timestep + [separator])
        tokens_per_sample.append(flat_tokens)
    batched_samples["tokens"] = tokens_per_sample
    return batched_samples


# Add the flattened "tokens" column; the separator is passed as a keyword argument
# rather than captured in a lambda.
dataset = dataset.map(
    batched_timeline_to_tokens,
    batched=True,
    fn_kwargs={"separator": SEPARATOR_TOKEN},
)
next(iter(dataset))

In [None]:
def batched_prepend_token(
    batched_samples: dict[str, list], token: str
) -> dict[str, list]:
    """Put ``token`` at the front of every sample's token list.

    The inner lists of ``batched_samples["tokens"]`` are modified in place.

    Returns:
        The same ``batched_samples`` dict.
    """
    for sample_tokens in batched_samples["tokens"]:
        sample_tokens.insert(0, token)
    return batched_samples


def batched_prepend_static_feature_token(
    batched_samples: dict[str, list], key: str
) -> dict[str, list]:
    """Prefix each sample's tokens with ``"<key>_<value>"``.

    The value is taken from the sample's own entry in the ``key`` column. The
    inner token lists are modified in place.

    Returns:
        The same ``batched_samples`` dict.
    """
    for position, sample_tokens in enumerate(batched_samples["tokens"]):
        feature_value = batched_samples[key][position]
        sample_tokens.insert(0, f"{key}_{feature_value}")
    return batched_samples


# Every insert goes to position 0, so the final prefix reads in the reverse of the
# order applied here: num_samples_<n>, num_blanks_<n>, start_idx_<n>, <SEP>, ...
dataset = dataset.map(
    batched_prepend_token,
    batched=True,
    fn_kwargs={"token": SEPARATOR_TOKEN},
)
# One map per static feature, instead of a copy-pasted call for each key.
for static_feature_key in ("start_idx", "num_blanks", "num_samples"):
    dataset = dataset.map(
        batched_prepend_static_feature_token,
        batched=True,
        fn_kwargs={"key": static_feature_key},
    )
next(iter(dataset))

# Make tokenizer

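Adapted from the Hugging Face NLP course's [WordPiece tokenizer from scratch](https://huggingface.co/learn/nlp-course/chapter6/8?fw=pt#building-a-wordpiece-tokenizer-from-scratch) walkthrough. Here a `WordLevel` model is used instead: the samples are already split into tokens, so each distinct token simply becomes one vocabulary entry.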

In [None]:
# Frequency of every token across the train split: a quick look at the vocabulary
# the tokenizer is about to learn.
token_count = Counter()
for sample_tokens in dataset["train"]["tokens"]:
    token_count.update(sample_tokens)
print(token_count)

In [None]:
# Word-level model: each distinct pre-split token is one vocabulary entry; tokens
# unseen during training map to UNKNOWN_TOKEN.
word_level_model = models.WordLevel(unk_token=UNKNOWN_TOKEN)
tokenizer = Tokenizer(word_level_model)

# SEPARATOR_TOKEN is deliberately not listed: it already occurs in every sample's
# tokens, so the trainer picks it up from the data like any other word.
special_tokens = [UNKNOWN_TOKEN, PADDING_TOKEN]
trainer = trainers.WordLevelTrainer(special_tokens=special_tokens)

In [None]:
# Learn the vocabulary from the pre-split token lists of the train split.
training_corpus = dataset["train"]["tokens"]
tokenizer.train_from_iterator(training_corpus, trainer=trainer)

In [None]:
# Encode the first sample (already split into tokens) and pair each token with its id.
first_sample_tokens = dataset["train"][0]["tokens"]
encoding = tokenizer.encode(first_sample_tokens, is_pretokenized=True)
print([(token, token_id) for token, token_id in zip(encoding.tokens, encoding.ids)])

In [None]:
# Wrap the trained `tokenizers.Tokenizer` in the project's fast-tokenizer class,
# registering the notebook's special tokens by role.
special_token_kwargs = {
    "unk_token": UNKNOWN_TOKEN,
    "pad_token": PADDING_TOKEN,
    "sep_token": SEPARATOR_TOKEN,
}
pretrained_fast_tokenizer = PreTrainedTokenizerFastWithPositionIDPadding(
    tokenizer_object=tokenizer, **special_token_kwargs
)

In [None]:
# Encode the first train sample with the wrapped tokenizer; its tokens are already
# split, so tell the tokenizer not to split them again.
first_sample_tokens = dataset["train"][0]["tokens"]
encoded_sample = pretrained_fast_tokenizer(first_sample_tokens, is_split_into_words=True)
encoded_sample

In [None]:
# Write the tokenizer files to SAVE_TOKENIZER_PATH.
pretrained_fast_tokenizer.save_pretrained(SAVE_TOKENIZER_PATH)

# Load them back through the same class so the next cell can compare the reloaded
# tokenizer against the in-memory one.
tokenizer_class = PreTrainedTokenizerFastWithPositionIDPadding
reloaded_tokenizer = tokenizer_class.from_pretrained(SAVE_TOKENIZER_PATH)

In [None]:
# Round-trip check. The reloaded tokenizer must encode sample 0 exactly like the
# in-memory one. It must also carry the same vocabulary, since sample 0 alone
# exercises only a subset of the tokens.
encoded_sample_reloaded = reloaded_tokenizer(
    dataset["train"][0]["tokens"], is_split_into_words=True
)
assert encoded_sample == encoded_sample_reloaded, (
    "reloaded tokenizer encodes sample 0 differently"
)
assert reloaded_tokenizer.get_vocab() == pretrained_fast_tokenizer.get_vocab(), (
    "reloaded tokenizer has a different vocabulary"
)

In [None]:
# Display a summary of the preprocessed dataset before tokenizing it.
dataset

In [None]:
def tokenize_batch(batch: dict[str, list]):
    """Tokenize each sample's pre-split ``tokens`` list, without token_type_ids."""
    return pretrained_fast_tokenizer(
        batch["tokens"], is_split_into_words=True, return_token_type_ids=False
    )


# All raw and intermediate columns are dropped, leaving only the tokenizer outputs.
raw_columns = [
    "timeline",
    "char_diffs",
    "start_idx",
    "num_blanks",
    "num_samples",
    "tokens",
]
encoded_dataset = dataset.map(tokenize_batch, batched=True, remove_columns=raw_columns)
encoded_dataset

In [None]:
# Preview the first encoded train sample; list-valued fields are cut to 10 entries.
# isinstance (rather than `type(...) == list`) also accepts list subclasses.
first_encoded_sample = encoded_dataset["train"][0]
for key, value in first_encoded_sample.items():
    print(key, value[:10] if isinstance(value, list) else value)

In [None]:
# Write the encoded dataset to SAVE_ENCODED_DATASET_PATH
# (outputs/encoded_dataset under the current working directory).
encoded_dataset.save_to_disk(SAVE_ENCODED_DATASET_PATH)