In [None]:
from collections import Counter
from math import floor
from pathlib import Path
from random import Random
from typing import Any, Iterator

from datasets import Dataset
from tokenizers import Tokenizer, models, trainers

In [None]:
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding

## Dummy Data

In [None]:
# Number of synthetic timelines to generate.
NUM_TIMELINES = 10000
# Alphabet the dummy timelines are drawn from.
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# NOTE: despite the name, this is a seeded `random.Random` instance (seed 23),
# not an integer seed. It is stateful: every draw advances it.
RANDOM_SEED = Random(23)
# Inclusive upper bound for the per-timeline `num_samples` draw in `get_samples`.
MAX_NUM_SAMPLES = 3

# Special token strings used when building token sequences and the tokenizer.
SEPARATOR_TOKEN = "<SEP>"
PADDING_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"

# Output locations, resolved relative to the current working directory.
OUTPUT_DIR = Path.cwd() / "outputs"
SAVE_TOKENIZER_PATH = OUTPUT_DIR / "tokenizer"
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR / "encoded_dataset"

In [None]:
def get_samples(num_timelines: int) -> Iterator[dict[str, Any]]:
    """Yield ``num_timelines`` randomly parameterised dummy timelines.

    Each yielded dict holds the generated ``timeline`` (a list of timesteps,
    each a list of single-letter strings) together with the three random
    parameters that produced it: ``start_idx``, ``skip_odd`` and
    ``num_samples``. All randomness comes from the module-level
    ``RANDOM_SEED`` generator, which is advanced by every call.
    """
    alphabet_size = len(LETTERS)
    for _ in range(num_timelines):
        # The order of these three draws fixes the random stream; keep it.
        first_pos = RANDOM_SEED.randint(0, alphabet_size // 2)
        drop_odd = RANDOM_SEED.choice([True, False])
        stride = RANDOM_SEED.randint(1, MAX_NUM_SAMPLES)

        # Walk the alphabet in strides of `stride`; a timestep that starts at
        # an odd index is left empty when `drop_odd` is set.
        steps: list[list[str]] = [
            [] if drop_odd and pos % 2 == 1 else list(LETTERS[pos : pos + stride])
            for pos in range(first_pos, alphabet_size, stride)
        ]

        yield {
            "timeline": steps,
            "start_idx": first_pos,
            "skip_odd": drop_odd,
            "num_samples": stride,
        }

In [None]:
# Materialise the generator into a Dataset; the generator argument is passed
# through `gen_kwargs` instead of being captured in a closure.
dataset = Dataset.from_generator(
    get_samples, gen_kwargs={"num_timelines": NUM_TIMELINES}
)
dataset[0]

In [None]:
def batched_timeline_to_tokens(
    batched_samples: dict[str, list], separator: str
) -> dict[str, list]:
    """Flatten each timeline into a token list stored under ``"tokens"``.

    Every timestep contributes ``separator`` followed by its own values, so an
    empty timestep still contributes a lone separator.

    Args:
        batched_samples: Batch with a ``"timeline"`` column; each entry is a
            list of timesteps, each timestep a list of values.
        separator: Token emitted in front of every timestep.

    Returns:
        The same ``batched_samples`` dict, with ``"tokens"`` set.
    """
    flattened = []
    for timeline in batched_samples["timeline"]:
        tokens = []
        for timestep in timeline:
            tokens.append(separator)
            tokens.extend(timestep)
        flattened.append(tokens)
    batched_samples["tokens"] = flattened
    return batched_samples


# Add the flattened "tokens" column; the separator is bound via `fn_kwargs`.
dataset = dataset.map(
    batched_timeline_to_tokens,
    batched=True,
    fn_kwargs={"separator": SEPARATOR_TOKEN},
)
first_row = dataset[0]
first_row["timeline"][:10], first_row["tokens"][:10]

In [None]:
def batched_insert_static_feature_token(
    batched_samples: dict[str, list], key: str, insert_idx: int
) -> dict[str, list]:
    """Insert a ``"<key>_<value>"`` token into every sample's token list.

    For sample ``i`` the token is built from ``batched_samples[key][i]`` and
    placed at position ``insert_idx`` of ``batched_samples["tokens"][i]``,
    following ``list.insert`` index semantics (an index past the end appends,
    a negative index counts from the end).

    New token lists are built instead of mutating the existing ones in place,
    so lists the caller still holds a reference to are left untouched.

    Args:
        batched_samples: Batch containing a ``"tokens"`` column and a ``key``
            column with one value per sample.
        key: Name of the column whose value is turned into a token.
        insert_idx: Position at which the token is inserted.

    Returns:
        The same ``batched_samples`` dict, with ``"tokens"`` replaced.

    Raises:
        KeyError: If ``"tokens"`` or ``key`` is missing from the batch.
        IndexError: If the ``key`` column has fewer entries than ``"tokens"``.
    """
    static_values = batched_samples[key]
    batched_samples["tokens"] = [
        [
            *tokens[:insert_idx],
            f"{key}_{static_values[idx]}",
            *tokens[insert_idx:],
        ]
        for idx, tokens in enumerate(batched_samples["tokens"])
    ]
    return batched_samples


# Insert the static features, in this order, at token positions 0, 1 and 2.
for insert_idx, key in enumerate(["start_idx", "skip_odd", "num_samples"]):
    dataset = dataset.map(
        batched_insert_static_feature_token,
        batched=True,
        fn_kwargs={"key": key, "insert_idx": insert_idx},
    )
dataset[0]["tokens"][:10]

## Add position IDs

In [None]:
def batched_add_position_ids(
    batched_samples: dict[str, list], separators: set[str]
) -> dict[str, list]:
    """Attach a ``"position_ids"`` column aligned with ``"tokens"``.

    Within each token sequence the position id starts at 0 and is incremented
    every time a token from ``separators`` is seen; the separator itself
    already carries the incremented value.

    Args:
        batched_samples: Batch with a ``"tokens"`` column of token sequences.
        separators: Tokens that advance the position counter.

    Returns:
        The same ``batched_samples`` dict, with ``"position_ids"`` set.
    """
    per_sample_ids: list[list[int]] = []
    batched_samples["position_ids"] = per_sample_ids
    for sequence in batched_samples["tokens"]:
        current = 0
        sequence_ids = []
        for token in sequence:
            current += 1 if token in separators else 0
            sequence_ids.append(current)
        per_sample_ids.append(sequence_ids)
    return batched_samples


# Add the "position_ids" column; only SEPARATOR_TOKEN advances the counter.
dataset = dataset.map(
    batched_add_position_ids,
    batched=True,
    fn_kwargs={"separators": {SEPARATOR_TOKEN}},
)
first_row = dataset[0]
print(list(zip(first_row["tokens"], first_row["position_ids"])))

In [None]:
# Fix the shuffle seed so the train/test split is identical on every
# Restart & Run All (same value as the seed given to `Random` in the
# configuration cell). Note that `dataset` becomes a DatasetDict with
# "train" and "test" splits from here on.
SPLIT_SEED = 23
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)
dataset

# Make tokenizer

Adapted from https://huggingface.co/learn/nlp-course/chapter6/8?fw=pt#building-a-wordpiece-tokenizer-from-scratch (the linked section builds a WordPiece tokenizer; this notebook uses a WordLevel model instead).

In [None]:
# Tally how often each token occurs across the training split.
token_count = Counter()
for train_tokens in dataset["train"]["tokens"]:
    token_count.update(train_tokens)
print(token_count)

In [None]:
# Build the tokenizer around a WordLevel model that uses UNKNOWN_TOKEN as its
# unknown token.
word_level_model = models.WordLevel(unk_token=UNKNOWN_TOKEN)
tokenizer = Tokenizer(word_level_model)
# Separator token is already in the dataset
trainer_special_tokens = [UNKNOWN_TOKEN, PADDING_TOKEN]
trainer = trainers.WordLevelTrainer(special_tokens=trainer_special_tokens)

In [None]:
# Train the tokenizer on the training split's token sequences only, so the
# test split stays unseen.
tokenizer.train_from_iterator(dataset["train"]["tokens"], trainer=trainer)

In [None]:
# Sanity check: encode the first training sample, which is already split into
# tokens, and show each token next to its id.
first_train_tokens = dataset["train"][0]["tokens"]
encoding = tokenizer.encode(first_train_tokens, is_pretokenized=True)
print(list(zip(encoding.tokens, encoding.ids)))

In [None]:
# Wrap the trained tokenizer in the project's fast-tokenizer class and tell it
# which strings are the unknown, padding and separator tokens.
# NOTE(review): judging by its name, this class presumably also pads a
# `position_ids` field — confirm against `foresight.tokenizers`.
pretrained_fast_tokenizer = PreTrainedTokenizerFastWithPositionIDPadding(
    tokenizer_object=tokenizer,
    unk_token=UNKNOWN_TOKEN,
    pad_token=PADDING_TOKEN,
    sep_token=SEPARATOR_TOKEN,
)

In [None]:
# Encode the first training sample with the wrapped tokenizer; the input is
# already a list of tokens, hence `is_split_into_words=True`.
encoded_sample = pretrained_fast_tokenizer(
    dataset["train"][0]["tokens"], is_split_into_words=True
)
encoded_sample

In [None]:
# Round-trip the tokenizer through disk: save it under SAVE_TOKENIZER_PATH and
# load it back with the same class.
pretrained_fast_tokenizer.save_pretrained(SAVE_TOKENIZER_PATH)
reloaded_tokenizer = PreTrainedTokenizerFastWithPositionIDPadding.from_pretrained(
    SAVE_TOKENIZER_PATH
)

In [None]:
# The reloaded tokenizer must encode the same sample exactly like the
# in-memory one; otherwise the save/load round trip lost information.
encoded_sample_reloaded = reloaded_tokenizer(
    dataset["train"][0]["tokens"], is_split_into_words=True
)
assert encoded_sample == encoded_sample_reloaded

In [None]:
def encode_batch(batch):
    """Run the wrapped tokenizer on a batch of pre-split token sequences."""
    return pretrained_fast_tokenizer(
        batch["tokens"], is_split_into_words=True, return_token_type_ids=False
    )


# Columns dropped from the encoded dataset; "position_ids" is not listed and
# is therefore kept.
raw_columns = ["timeline", "start_idx", "skip_odd", "num_samples", "tokens"]
encoded_dataset = dataset.map(encode_batch, batched=True, remove_columns=raw_columns)
encoded_dataset

In [None]:
# Preview the first encoded training row; list-valued fields are truncated to
# their first 10 entries, anything else is printed as is.
for key, value in encoded_dataset["train"][0].items():
    print(key, value[:10] if isinstance(value, list) else value)

In [None]:
# Persist the encoded train/test splits under SAVE_ENCODED_DATASET_PATH.
encoded_dataset.save_to_disk(SAVE_ENCODED_DATASET_PATH)