In [None]:
from itertools import islice
from pathlib import Path
from random import Random
from typing import Any, Iterator

from datasets import Dataset
from tokenizers import Tokenizer, models, trainers

In [None]:
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding

## Dummy Data

In [None]:
# --- Synthetic data generation ---------------------------------------------
# Number of dummy timelines to generate.
NUM_TIMELINES = 1_000
# Timeline events are generated only for years strictly below YEAR_CUTOFF;
# training tokens keep only events strictly below VAL_YEAR_CUTOFF.
YEAR_CUTOFF, VAL_YEAR_CUTOFF = 2024, 2020
# Each character is treated as one condition code.
CONDITIONS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# NOTE: despite its name this is a seeded random.Random *instance* (seed 23),
# shared by the generator function below.
RANDOM_SEED = Random(23)
# NOTE(review): not referenced anywhere in the visible notebook - confirm
# whether it is still needed.
MAX_NUM_SAMPLES = 3

# --- Special tokens ----------------------------------------------------------
SEPARATOR_TOKEN, PADDING_TOKEN, UNKNOWN_TOKEN = "<SEP>", "<PAD>", "<UNK>"
# The last condition code doubles as the end-of-sequence token.
EOS_TOKEN = "Z"

# --- Output locations (relative to the current working directory) -----------
OUTPUT_DIR = Path.cwd().joinpath("outputs")
SAVE_TOKENIZER_PATH = OUTPUT_DIR.joinpath("tokenizer")
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR.joinpath("encoded_dataset")
SAVE_RAW_DATASET_PATH = OUTPUT_DIR.joinpath("raw_dataset")

In [None]:
def get_samples(
    num_timelines: int,
    year_cutoff: int,
    rng: "Random | None" = None,
    conditions: "str | None" = None,
) -> Iterator[dict[str, Any]]:
    """Yield synthetic timelines whose shape is determined by static features.

    For every timeline a year of birth, sex and ethnicity are drawn at random.
    The sex fixes the number of years between consecutive events (1 for "M",
    2 otherwise) and the ethnicity fixes how many condition codes are emitted
    per event (1 for "ETH_1", 2 for "ETH_2"). Conditions are consecutive slices
    of ``conditions`` starting from a random index; generation stops when the
    year reaches ``year_cutoff`` or the condition codes are exhausted.

    Args:
        num_timelines: Number of samples to yield.
        year_cutoff: Exclusive upper bound on event years.
        rng: Random number generator to draw from. Defaults to the
            module-level ``RANDOM_SEED`` instance, which is advanced by every
            call; pass a freshly seeded ``Random`` for a reproducible run.
        conditions: String whose characters are the condition codes. Defaults
            to the module-level ``CONDITIONS``.

    Yields:
        Dicts with keys ``timeline`` (list of lists of condition codes),
        ``timestamps`` (event years, parallel to ``timeline``),
        ``year_of_birth``, ``sex`` and ``ethnicity``.
    """
    # Module-level defaults are resolved at call time, so existing two-argument
    # callers behave exactly as before.
    rng = RANDOM_SEED if rng is None else rng
    conditions = CONDITIONS if conditions is None else conditions
    conditions_per_event = {"ETH_1": 1, "ETH_2": 2}

    for _ in range(num_timelines):
        # The order of the draws below determines the generated data for a
        # given seed; do not reorder them.
        year_of_birth = rng.randint(2005, 2010)
        sex = rng.choice(["M", "F"])
        ethnicity = rng.choice(["ETH_1", "ETH_2"])

        time_step = 1 if sex == "M" else 2
        num_samples = conditions_per_event[ethnicity]
        start_condition_idx = rng.randint(0, len(conditions) - 1)
        timestamp = year_of_birth + rng.randint(0, 10)

        timeline: list[list[str]] = []
        timestamps: list[int] = []

        while timestamp < year_cutoff and start_condition_idx < len(conditions):
            # Clamp so the final slice never runs past the last condition code.
            end_condition_idx = min(start_condition_idx + num_samples, len(conditions))
            timeline.append(list(conditions[start_condition_idx:end_condition_idx]))
            timestamps.append(timestamp)

            start_condition_idx = end_condition_idx
            timestamp += time_step

        yield {
            "timeline": timeline,
            "timestamps": timestamps,
            "year_of_birth": year_of_birth,
            "sex": sex,
            "ethnicity": ethnicity,
        }

In [None]:
# Materialise the synthetic timelines as a Hugging Face Dataset. The lambda
# binds the notebook-level configuration to the generator function.
dataset = Dataset.from_generator(lambda: get_samples(NUM_TIMELINES, YEAR_CUTOFF))
# Preview the first five generated rows.
for data in islice(dataset, 5):
    print(data)

In [None]:
def batched_timeline_to_train_tokens(
    batched_samples: dict[str, list], separator: str, val_year_cutoff: int
) -> dict[str, list]:
    """Flatten each timeline in a batch into a list of training tokens.

    Every event whose year is strictly below ``val_year_cutoff`` contributes
    ``time_diff_<n>``, then its condition codes, then ``separator``. ``<n>`` is
    the number of years since the previous event (or since the year of birth
    for the first event); it is computed from the full timeline, including
    events that are filtered out by the cutoff.

    Args:
        batched_samples: Column-oriented batch with the keys
            ``year_of_birth``, ``timestamps`` and ``timeline``.
        separator: Token appended after each kept event.
        val_year_cutoff: Exclusive upper bound on the years kept for training.

    Returns:
        The same ``batched_samples`` mapping with a new ``tokens`` column.
    """
    token_rows = []
    for birth_year, years, events in zip(
        batched_samples["year_of_birth"],
        batched_samples["timestamps"],
        batched_samples["timeline"],
    ):
        row_tokens = []
        previous_year = birth_year
        for event_conditions, year in zip(events, years):
            # The gap is tracked for every event so that a skipped event
            # still shifts the reference year of the next one.
            gap = year - previous_year
            previous_year = year
            if year < val_year_cutoff:
                row_tokens.extend([f"time_diff_{gap}"] + event_conditions + [separator])
        token_rows.append(row_tokens)

    batched_samples["tokens"] = token_rows
    return batched_samples


# Add the "tokens" column: events before VAL_YEAR_CUTOFF are flattened into
# time_diff / condition / separator tokens, one batch at a time.
dataset = dataset.map(
    lambda batch: batched_timeline_to_train_tokens(
        batch, SEPARATOR_TOKEN, VAL_YEAR_CUTOFF
    ),
    batched=True,
)
# Preview the first five rows, now including the "tokens" column.
for data in islice(dataset, 5):
    print(data)

In [None]:
def batched_prepend_token(
    batched_samples: dict[str, list], token: str
) -> dict[str, list]:
    """Insert ``token`` at the front of every token list in the batch.

    The lists under ``batched_samples["tokens"]`` are modified in place and
    the same mapping is returned.
    """
    for sample_tokens in batched_samples["tokens"]:
        sample_tokens.insert(0, token)
    return batched_samples


def batched_prepend_static_feature_token(
    batched_samples: dict[str, list], key: str
) -> dict[str, list]:
    """Insert a ``<key>_<value>`` token at the front of every token list.

    ``<value>`` is the entry of column ``key`` at the same batch position. The
    lists under ``batched_samples["tokens"]`` are modified in place and the
    same mapping is returned.
    """
    for position, sample_tokens in enumerate(batched_samples["tokens"]):
        feature_value = batched_samples[key][position]
        sample_tokens.insert(0, f"{key}_{feature_value}")
    return batched_samples


# Build the static prefix of every token list. Each step inserts at position
# 0, so the prefix is assembled right-to-left and the final order is:
#   ethnicity_*, sex_*, year_of_birth_*, <SEP>, <timeline tokens...>
# NOTE: this cell is not idempotent - re-running it without re-running the
# previous cell prepends the prefix a second time.
dataset = dataset.map(
    batched_prepend_token,
    batched=True,
    fn_kwargs={"token": SEPARATOR_TOKEN},
)
# One map per static feature; passing the column name through fn_kwargs
# replaces three copy-pasted lambdas that differed only in that string.
for static_feature in ("year_of_birth", "sex", "ethnicity"):
    dataset = dataset.map(
        batched_prepend_static_feature_token,
        batched=True,
        fn_kwargs={"key": static_feature},
    )
next(iter(dataset))

In [None]:
# Hold out 10% of the timelines as the test split. The seed is fixed (same
# value as the data-generation RNG) so that the partition is identical on
# every "Restart & Run All"; without it the shuffle differs between runs.
dataset = dataset.train_test_split(test_size=0.1, seed=23)

In [None]:
# Persist the un-encoded train/test splits (raw columns plus "tokens").
dataset.save_to_disk(SAVE_RAW_DATASET_PATH)

# Make tokenizer

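Adapted from the Hugging Face NLP course section on building a tokenizer from scratch: https://huggingface.co/learn/nlp-course/chapter6/8?fw=pt#building-a-wordpiece-tokenizer-from-scratch. Unlike that tutorial, this notebook uses a `WordLevel` model instead of WordPiece.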

In [None]:
# NOTE: despite the "_count" suffix this is the *set* of distinct tokens found
# in the training split; no frequencies are stored.
inferred_tokens_count = set().union(*dataset["train"]["tokens"])
# Temporal tokens are enumerated explicitly (time_diff_0..10 and
# year_of_birth_2000..2029), presumably so the vocabulary also covers values
# that do not occur in the training split.
temporal_tokens = {
    f"{prefix}_{value}"
    for prefix, values in (
        ("time_diff", range(11)),
        ("year_of_birth", range(2000, 2030)),
    )
    for value in values
}

In [None]:
# Word-level model: every distinct token string gets its own id, and unseen
# strings fall back to UNKNOWN_TOKEN.
tokenizer = Tokenizer(models.WordLevel(unk_token=UNKNOWN_TOKEN))
# Separator and end of sequence tokens are already in the dataset, so only the
# unknown and padding tokens are registered as special tokens here.
trainer = trainers.WordLevelTrainer(special_tokens=[UNKNOWN_TOKEN, PADDING_TOKEN])

In [None]:
# Build the vocabulary from the union of the tokens observed in the training
# split and the explicitly enumerated temporal tokens.
tokenizer.train_from_iterator(inferred_tokens_count | temporal_tokens, trainer=trainer)

In [None]:
# Sanity check: encode the first training sample. It is already a list of
# tokens, hence is_pretokenized=True. Show the resulting (token, id) pairs.
encoding = tokenizer.encode(dataset["train"][0]["tokens"], is_pretokenized=True)
print(list(zip(encoding.tokens, encoding.ids)))

In [None]:
# Wrap the trained tokenizer in the project's fast-tokenizer class and declare
# which vocabulary entries act as the unknown, padding, separator and
# end-of-sequence tokens.
# NOTE(review): the position-ID padding behaviour implied by the class name is
# implemented in foresight.tokenizers and is not visible in this notebook.
pretrained_fast_tokenizer = PreTrainedTokenizerFastWithPositionIDPadding(
    tokenizer_object=tokenizer,
    unk_token=UNKNOWN_TOKEN,
    pad_token=PADDING_TOKEN,
    sep_token=SEPARATOR_TOKEN,
    eos_token=EOS_TOKEN,
)

In [None]:
# Encode the first training sample with the wrapped tokenizer. The sample is
# already a list of tokens, hence is_split_into_words=True.
encoded_sample = pretrained_fast_tokenizer(
    dataset["train"][0]["tokens"], is_split_into_words=True
)
# Display the encoding as the cell output.
encoded_sample

In [None]:
# Persist the tokenizer and immediately load it back, so the next cell can
# check that saving and loading round-trips.
pretrained_fast_tokenizer.save_pretrained(SAVE_TOKENIZER_PATH)
reloaded_tokenizer = PreTrainedTokenizerFastWithPositionIDPadding.from_pretrained(
    SAVE_TOKENIZER_PATH
)

In [None]:
# Round-trip check: the reloaded tokenizer must encode the same sample exactly
# like the in-memory tokenizer did.
encoded_sample_reloaded = reloaded_tokenizer(
    dataset["train"][0]["tokens"], is_split_into_words=True
)
assert encoded_sample == encoded_sample_reloaded

In [None]:
# Show the train/test splits and their columns before encoding.
dataset

In [None]:
# Tokenize the "tokens" column of every split. The inputs are already lists of
# tokens (is_split_into_words=True) and token type ids are not requested.
encoded_dataset = dataset.map(
    lambda batch: pretrained_fast_tokenizer(
        batch["tokens"], is_split_into_words=True, return_token_type_ids=False
    ),
    batched=True,
    # Drop every column created earlier in this notebook so that only the
    # tokenizer outputs remain in the encoded dataset.
    remove_columns=[
        "timeline",
        "timestamps",
        "year_of_birth",
        "sex",
        "ethnicity",
        "tokens",
    ],
)
# Display the resulting splits and columns as the cell output.
encoded_dataset

In [None]:
# Preview the first encoded training row; list-valued columns are truncated to
# their first ten entries to keep the output short.
for key, value in encoded_dataset["train"][0].items():
    # isinstance() is the idiomatic type check (an exact type() comparison
    # would also reject list subclasses).
    print(key, value[:10] if isinstance(value, list) else value)

In [None]:
# Persist the encoded train/test splits.
encoded_dataset.save_to_disk(SAVE_ENCODED_DATASET_PATH)