In [None]:
import os
import datasets
from dataclasses import dataclass
from transformers import AutoTokenizer, PreTrainedTokenizerBase

In [None]:
# Load the tokenizer published as 'thng292/fineweb-vi-en-tokenizer'; later cells
# hand it to TokenizeAndPack to tokenize both dataset splits.
fw_tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained('thng292/fineweb-vi-en-tokenizer')

In [None]:
# Show the tokenizer's pad token as this cell's output (inspection only; the
# value is not stored).
fw_tokenizer.pad_token

In [None]:
# Load the "subset-en" and "subset-vi" configurations of the same dataset repository.
data_en: datasets.DatasetDict = datasets.load_dataset("thng292/fineweb-subset-1M", "subset-en")
data_vi: datasets.DatasetDict = datasets.load_dataset("thng292/fineweb-subset-1M", "subset-vi")

# Merge the two configurations split by split into one train and one test dataset.
# NOTE(review): neither `probabilities` nor `stopping_strategy` is passed, so the
# library defaults apply — presumably strict alternation that stops as soon as
# the shorter source runs out. Confirm both configurations have equally sized
# splits, otherwise the tail of the longer one is dropped.
train = datasets.interleave_datasets([data_en["train"], data_vi["train"]])
test = datasets.interleave_datasets([data_en["test"], data_vi["test"]])

In [None]:
# Token capacity of one packed row; passed as `max_length` to TokenizeAndPack
# in the tokenization cells below.
max_len = 4096

In [None]:
import bisect
import gc


def search_for_fit(numbers: list[tuple[int, int]], capacity: int) -> int:
    r"""Find the index of largest number that fits into the knapsack with the given capacity."""
    index = bisect.bisect(numbers, capacity, key=lambda a: a[1])
    return index - 1


def greedy_knapsack(
    numbers: list[tuple[int, int]],  # should be `list(enumerate(lengths + 1))`
    capacity: int,
) -> list[list[tuple[int, int]]]:
    r"""Greedily pack ``(index, weight)`` pairs into knapsacks of a fixed capacity.

    Each knapsack is filled by repeatedly taking the heaviest remaining item that
    still fits (located with a binary search) until nothing else fits; a new
    knapsack is then opened, until every item has been placed.

    Args:
        numbers: ``(index, weight)`` pairs to pack. The list is copied before
            sorting, so the caller's list is left unmodified.
        capacity: Maximum total weight of a single knapsack.

    Returns:
        A list of knapsacks; each knapsack is the list of ``(index, weight)``
        pairs assigned to it, in the order they were picked (heaviest first).
        An empty input yields an empty list.

    Raises:
        ValueError: If any item weighs more than ``capacity``. Such an item can
            never be placed, so without this check the packing loop would keep
            opening empty knapsacks forever.
    """
    # Binary search needs ascending weights. Sorting a copy keeps the caller's
    # list intact instead of reordering and draining it.
    pending = sorted(numbers, key=lambda item: item[1])

    # After sorting, the heaviest item is last; if it does not fit an empty
    # knapsack it never will, and the loop below could not terminate.
    if pending and pending[-1][1] > capacity:
        raise ValueError(
            f"item weight {pending[-1][1]} exceeds knapsack capacity {capacity}"
        )

    knapsacks = []

    while pending:
        current_knapsack = []
        remaining_capacity = capacity

        while True:
            index = search_for_fit(pending, remaining_capacity)
            if index == -1:
                break  # nothing left that fits into this knapsack

            remaining_capacity -= pending[index][1]  # reserve room for the item
            current_knapsack.append(pending.pop(index))  # move the item into the knapsack

        knapsacks.append(current_knapsack)

    return knapsacks


@dataclass
class TokenizeAndPack:
    """Callable for a batched ``map`` that tokenizes texts and packs them into rows.

    Every kept document is followed by the tokenizer's EOS token and several
    documents are concatenated into one row of at most ``max_length`` tokens.
    Rows are not padded.

    Attributes:
        tokenizer: Tokenizer applied to the batch's ``"text"`` column.
        max_length: Token capacity of one packed row.
    """

    tokenizer: PreTrainedTokenizerBase
    max_length: int

    def __call__(self, row: dict[str, list]):
        """Tokenize and pack one batch.

        Args:
            row: Batch mapping that holds the raw strings under ``"text"``.

        Returns:
            A dict of parallel lists with one entry per packed row:
            ``"input_ids"`` (concatenated token ids, each document followed by
            EOS), ``"attention_mask"`` (for every token, the 0-based position of
            its document within the row — not a 0/1 mask) and ``"length"``
            (number of tokens in the row).
        """
        encoded = self.tokenizer(row["text"])

        # Documents with max_length or more tokens are dropped: they would not
        # fit in a row once the EOS token is appended.
        kept_docs = [
            doc_ids
            for doc_ids in encoded.data["input_ids"]
            if len(doc_ids) < self.max_length
        ]

        # Weight of a document = its token count plus one slot for the EOS token.
        weighted_docs = [
            (doc_index, len(doc_ids) + 1) for doc_index, doc_ids in enumerate(kept_docs)
        ]

        packed_ids: list[list[int]] = []
        packed_masks: list[list[int]] = []
        for pack in greedy_knapsack(weighted_docs, capacity=self.max_length):
            row_ids = []
            row_mask = []
            for segment, (doc_index, weight) in enumerate(pack):
                row_ids.extend([*kept_docs[doc_index], self.tokenizer.eos_token_id])
                row_mask.extend([segment] * weight)
                assert len(row_ids) == len(row_mask)
            packed_ids.append(row_ids)
            packed_masks.append(row_mask)
        gc.collect()

        return {
            "input_ids": packed_ids,
            "attention_mask": packed_masks,
            "length": [len(row_ids) for row_ids in packed_ids],
        }

In [None]:
# Tokenize and pack the training split batch by batch, using one worker per CPU.
# The row capacity is already stored on the TokenizeAndPack instance, so it is
# not repeated through `fn_kwargs`: `map` forwards those as keyword arguments to
# the callable, and `TokenizeAndPack.__call__` accepts only the batch, so the
# extra `max_length` keyword raised a TypeError.
# The source columns are removed because each batch returns packed rows instead
# of one row per input text.
tokenized_train = train.map(
    TokenizeAndPack(fw_tokenizer, max_length=max_len),
    num_proc=os.cpu_count() or 1,
    batched=True,
    remove_columns=train.column_names,
)

In [None]:
# Tokenize and pack the test split batch by batch, using one worker per CPU.
# The row capacity is already stored on the TokenizeAndPack instance, so it is
# not repeated through `fn_kwargs`: `map` forwards those as keyword arguments to
# the callable, and `TokenizeAndPack.__call__` accepts only the batch, so the
# extra `max_length` keyword raised a TypeError.
# The source columns are removed because each batch returns packed rows instead
# of one row per input text.
tokenized_test = test.map(
    TokenizeAndPack(fw_tokenizer, max_length=max_len),
    num_proc=os.cpu_count() or 1,
    batched=True,
    remove_columns=test.column_names,
)

In [None]:
# Fetch the secret stored under "HF_TOKEN" through the kaggle_secrets client;
# the upload cell passes it as the `token` for the Hub push.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("HF_TOKEN")

In [None]:
# Bundle the packed splits under their split names and upload them to the Hub
# repository "fw-experiment-1-tokenized-packed" using the previously fetched token.
experiment_1 = datasets.DatasetDict(
    {"train": tokenized_train, "test": tokenized_test}
)
experiment_1.push_to_hub("fw-experiment-1-tokenized-packed", token=secret_value_0)