In [None]:
from collections import Counter
from pathlib import Path
from random import Random

from datasets import Dataset

In [None]:
from foresight.tokenizers.simple_map_tokenizer_v2 import SimpleMapTokenizer

## Dummy Data

In [None]:
# --- Synthetic data -------------------------------------------------------
NUM_TIMELINES = 1000
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  # alphabet the timestep values are sliced from
NUM_TIMESTEPS = 10
RANDOM_SEED = Random(23)  # NOTE: a seeded RNG instance, not an integer seed
MAX_SKIP = 10
MAX_NUM_SAMPLES = 3  # inclusive upper bound on letters per timestep

# --- Token layout ---------------------------------------------------------
SEQUENCE_LENGTH = 12  # timesteps generated per timeline
SEPARATOR_TOKEN = "<SEP>"  # emitted in front of every timestep
PADDING_TOKEN = "<PAD>"

# --- Output locations (relative to the current working directory) ---------
OUTPUT_DIR = Path.cwd().joinpath("outputs")
SAVE_TOKENIZER_PATH = OUTPUT_DIR.joinpath("tokenizer")
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR.joinpath("encoded_dataset")

In [None]:
def get_samples(
    num_timelines=None,
    rng=None,
    letters=None,
    max_num_samples=None,
    sequence_length=None,
):
    """Yield synthetic timelines made of consecutive letter windows.

    For every timeline a start offset, a "skip odd offsets" flag and a window
    width are drawn from ``rng``. Timestep ``i`` then holds the letters
    ``letters[start + width * i : start + width * (i + 1)]``, or an empty list
    when the offset is odd (and ``skip_odd`` is set) or the window would run
    past the end of ``letters``.

    Every argument is optional; ``None`` falls back to the matching notebook
    constant, so calling ``get_samples()`` with no arguments (as
    ``Dataset.from_generator`` does) keeps working.

    Args:
        num_timelines: Number of samples to yield (default ``NUM_TIMELINES``).
        rng: Object exposing ``randint(a, b)`` and ``choice(seq)``
            (default ``RANDOM_SEED``).
        letters: Sequence the windows are sliced from (default ``LETTERS``).
        max_num_samples: Inclusive upper bound for the window width
            (default ``MAX_NUM_SAMPLES``).
        sequence_length: Timesteps per timeline (default ``SEQUENCE_LENGTH``).

    Yields:
        dict with keys ``"timeline"`` (list of lists of single characters),
        ``"start_idx"`` (int), ``"skip_odd"`` (bool) and ``"num_samples"`` (int).
    """
    if num_timelines is None:
        num_timelines = NUM_TIMELINES
    if rng is None:
        rng = RANDOM_SEED
    if letters is None:
        letters = LETTERS
    if max_num_samples is None:
        max_num_samples = MAX_NUM_SAMPLES
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH

    for _ in range(num_timelines):
        # Floor division keeps the bound an int: ``randint`` rejects float
        # bounds with a TypeError on Python 3.12+ (deprecated since 3.10).
        start_idx = rng.randint(0, len(letters) // 2)
        skip_odd = rng.choice([True, False])
        num_samples = rng.randint(1, max_num_samples)

        timeline = []
        for seq_idx in range(sequence_length):
            char_idx = start_idx + num_samples * seq_idx
            window_end = char_idx + num_samples
            skipped = skip_odd and char_idx % 2 == 1
            # A window ending exactly at len(letters) is still complete, so
            # only windows that overshoot the alphabet are dropped.
            if skipped or window_end > len(letters):
                timeline.append([])
            else:
                timeline.append(list(letters[char_idx:window_end]))

        yield {
            "timeline": timeline,
            "start_idx": start_idx,
            "skip_odd": skip_odd,
            "num_samples": num_samples,
        }

In [None]:
# Materialise everything ``get_samples`` yields into a Hugging Face ``Dataset``.
dataset = Dataset.from_generator(generator=get_samples)

In [None]:
# Inspect the first generated record.
dataset[0]

In [None]:
def batched_timeline_to_tokens(
    batched_samples: dict[str, list], separator: str
) -> dict[str, list]:
    """Flatten each timeline of a batch into a single token list.

    Every timestep contributes ``separator`` followed by its own values, so an
    empty timestep still contributes one separator.

    Args:
        batched_samples: Batch with a ``"timeline"`` column holding, per row,
            a list of timesteps (each a list of values).
        separator: Token placed in front of every timestep.

    Returns:
        The same ``batched_samples`` object with a ``"tokens"`` column added.
    """
    flattened_rows = []
    for timeline in batched_samples["timeline"]:
        row_tokens = []
        for timestep in timeline:
            row_tokens.extend([separator] + timestep)
        flattened_rows.append(row_tokens)

    batched_samples["tokens"] = flattened_rows
    return batched_samples

In [None]:
# Add the flattened "tokens" column; the separator is forwarded through
# ``fn_kwargs`` rather than captured in a lambda.
dataset = dataset.map(
    batched_timeline_to_tokens,
    batched=True,
    fn_kwargs={"separator": SEPARATOR_TOKEN},
)
(dataset[0]["timeline"][:10], dataset[0]["tokens"][:10])

In [None]:
def batched_insert_static_feature_token(
    batched_samples: dict[str, list], key: str, insert_idx: int
) -> dict[str, list]:
    """Insert a ``"<key>_<value>"`` token into every row's token list.

    The value is read from column ``key`` of the same row. The token lists in
    ``batched_samples["tokens"]`` are modified in place.

    Args:
        batched_samples: Batch containing a ``"tokens"`` column and a column
            named ``key``.
        key: Name of the per-row feature column to turn into a token.
        insert_idx: Position passed to ``list.insert`` for the new token.

    Returns:
        The same ``batched_samples`` object.
    """
    for row_idx, row_tokens in enumerate(batched_samples["tokens"]):
        feature_token = f"{key}_{batched_samples[key][row_idx]}"
        row_tokens.insert(insert_idx, feature_token)
    return batched_samples


# Put the three per-timeline features in front of the token stream, in the
# fixed order start_idx (0), skip_odd (1), num_samples (2).
dataset = (
    dataset.map(
        batched_insert_static_feature_token,
        batched=True,
        fn_kwargs={"key": "start_idx", "insert_idx": 0},
    )
    .map(
        batched_insert_static_feature_token,
        batched=True,
        fn_kwargs={"key": "skip_odd", "insert_idx": 1},
    )
    .map(
        batched_insert_static_feature_token,
        batched=True,
        fn_kwargs={"key": "num_samples", "insert_idx": 2},
    )
)
dataset[0]["tokens"][:10]

## Add position IDs

In [None]:
def batched_add_position_ids(
    batched_samples: dict[str, list], separators: set[str]
) -> dict[str, list]:
    """Attach a running separator count to every token of every row.

    A token's position id is the number of separator tokens seen in its row up
    to and including that token, so tokens before the first separator get 0
    and a separator shares its id with the tokens that follow it.

    Args:
        batched_samples: Batch containing a ``"tokens"`` column.
        separators: Tokens that advance the position counter.

    Returns:
        The same ``batched_samples`` object with a ``"position_ids"`` column
        added (one list of ints per row, aligned with ``"tokens"``).
    """
    batched_samples["position_ids"] = all_position_ids = []
    for row_tokens in batched_samples["tokens"]:
        separators_seen = 0
        row_positions = []
        for token in row_tokens:
            separators_seen += 1 if token in separators else 0
            row_positions.append(separators_seen)
        all_position_ids.append(row_positions)
    return batched_samples


# Add the "position_ids" column; only SEPARATOR_TOKEN advances the counter.
dataset = dataset.map(
    batched_add_position_ids,
    batched=True,
    fn_kwargs={"separators": {SEPARATOR_TOKEN}},
)
# Show the first ten (token, position id) pairs of the first row.
[(dataset[0]["tokens"][idx], dataset[0]["position_ids"][idx]) for idx in range(10)]

In [None]:
# Seed the shuffle so the train/test partition, and everything derived from it
# (token counts, tokenizer, encoded dataset), is identical on every
# Restart & Run All. 23 mirrors the value used to seed the data RNG.
SPLIT_SEED = 23
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)
len(dataset["train"]), len(dataset["test"])

# Make tokenizer

In [None]:
# Count how often each token occurs across all rows of the training split
# only. The keys of this counter are what a later cell passes to
# SimpleMapTokenizer.
token_count = Counter(
    token for tokens in dataset["train"]["tokens"] for token in tokens
)
# Display the per-token counts.
token_count

In [None]:
# Build the tokenizer from the distinct tokens observed in the training split.
# NOTE(review): tokens occurring only in the test split are not passed here;
# confirm how SimpleMapTokenizer treats tokens it was not constructed with.
tokenizer = SimpleMapTokenizer(token_count.keys())
# Encode the first training row as a smoke test. The result is used as a
# mapping (``.items()``) whose values support slicing.
encoded_sample = tokenizer.encode(dataset["train"][0]["tokens"])
for key, value in encoded_sample.items():
    # Print only the first ten entries of each encoded field.
    print(key, value[:10])

In [None]:
# Round-trip check: a tokenizer reloaded from disk must encode a row exactly
# like the in-memory tokenizer it was saved from.
# Whether ``SimpleMapTokenizer.save`` creates missing parent directories is not
# visible here, so make sure the parent exists for a run on a fresh checkout.
SAVE_TOKENIZER_PATH.parent.mkdir(parents=True, exist_ok=True)
tokenizer.save(SAVE_TOKENIZER_PATH)
reloaded_tokenizer = SimpleMapTokenizer.load(SAVE_TOKENIZER_PATH)
# Fetch the row first, then its "tokens" field: the same access pattern used
# to compute ``encoded_sample``, and it avoids reading the whole column.
encoded_sample_reloaded = reloaded_tokenizer.encode(dataset["train"][0]["tokens"])
assert (
    encoded_sample == encoded_sample_reloaded
), "Reloaded tokenizer encodes differently from the original tokenizer"


In [None]:
# Encode every split batch-wise; each batch is handed straight to
# ``tokenizer.batch_encode``.
encoded_dataset = dataset.map(tokenizer.batch_encode, batched=True)

In [None]:
# Preview the first encoded training row: list-valued columns are cut to their
# first ten items, every other value is printed as-is.
for key, value in encoded_dataset["train"][0].items():
    # isinstance is the idiomatic type check and also accepts list subclasses.
    print(key, value[:10] if isinstance(value, list) else value)


In [None]:
# Persist the encoded train/test splits to SAVE_ENCODED_DATASET_PATH.
encoded_dataset.save_to_disk(SAVE_ENCODED_DATASET_PATH)