In [None]:
import time
from pathlib import Path

import datasets
import torch
from transformers import DataCollatorWithPadding

from foresight.datasets.data_collator_v2 import (
    DataCollatorForLanguageModelingMaskStaticVariables,
)
from foresight.models.foresight_llama import ForesightLlamaForCausalLM
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding

In [None]:
# Every artefact this notebook reads lives under ./outputs.
OUTPUT_DIR = Path.cwd() / "outputs"
SAVE_TOKENIZER_PATH = OUTPUT_DIR / "tokenizer"
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR / "encoded_dataset"
MODEL_LOGS_ROOT = OUTPUT_DIR / "model_logs"
RUN_TIMESTAMP_FORMAT = "%Y_%m_%d_%H_%M_%S"


def find_latest_trained_run(logs_root):
    """Return the newest run directory under ``logs_root`` that holds a model.

    Run directories are named with ``RUN_TIMESTAMP_FORMAT``, so the
    lexicographically greatest name is also the most recent run. Only runs
    that contain a ``final_model`` sub-directory are considered.

    Args:
        logs_root: Directory whose children are per-run log directories.

    Returns:
        The ``Path`` of the newest run with a ``final_model`` directory, or
        ``None`` when ``logs_root`` does not exist or has no such run.
    """
    if not logs_root.is_dir():
        return None
    trained_runs = [
        run_dir
        for run_dir in logs_root.iterdir()
        if (run_dir / "final_model").is_dir()
    ]
    return max(trained_runs, key=lambda run_dir: run_dir.name, default=None)


# A directory stamped with "now" is always empty, so loading FINAL_MODEL_DIR
# from it can never succeed. Prefer the most recent run that already contains
# a saved model, and only fall back to a fresh timestamped directory when no
# such run exists.
_latest_run = find_latest_trained_run(MODEL_LOGS_ROOT)
MODEL_LOGS_DIR = (
    _latest_run
    if _latest_run is not None
    else MODEL_LOGS_ROOT / time.strftime(RUN_TIMESTAMP_FORMAT)
)
FINAL_MODEL_DIR = MODEL_LOGS_DIR / "final_model"
MODEL_LOGS_DIR.mkdir(parents=True, exist_ok=True)

# Passed to the training collator as `num_static_variables`.
NUM_STATIC_VARIABLES = 4

In [None]:
# Read the dataset stored at SAVE_ENCODED_DATASET_PATH back from disk; the
# bare name on the last line displays its summary as the cell output.
encoded_dataset = datasets.load_from_disk(
    dataset_path=SAVE_ENCODED_DATASET_PATH,
)
encoded_dataset

In [None]:
# Use the GPU when one is visible, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model saved under FINAL_MODEL_DIR and move it to the chosen device.
# `.to()` returns the model, so the cell output is its module structure.
model = ForesightLlamaForCausalLM.from_pretrained(FINAL_MODEL_DIR)
model.to(DEVICE)

In [None]:
# Restore the tokenizer that was saved under SAVE_TOKENIZER_PATH.
tokenizer = PreTrainedTokenizerFastWithPositionIDPadding.from_pretrained(
    SAVE_TOKENIZER_PATH,
)

# Collator configured as for training: `mlm=False` and the number of static
# variables to mask.
# NOTE(review): presumably mirrors the training notebook's setup - confirm.
training_data_collator = DataCollatorForLanguageModelingMaskStaticVariables(
    tokenizer=tokenizer,
    mlm=False,
    num_static_variables=NUM_STATIC_VARIABLES,
)

In [None]:
# Pad on the left for inference so every sequence in the batch ends at the
# final position. Note this mutates the shared `tokenizer` object.
tokenizer.padding_side = "left"
inference_data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Collate the first two examples of the "test" split into one padded batch.
batch = inference_data_collator(encoded_dataset["test"][:2])

# Send every tensor to the device the model actually lives on instead of a
# hard-coded "cuda", so inputs and weights can never end up on different
# devices.
batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

In [None]:
# Run generation on the collated batch, then copy the resulting ids back to
# host memory; the last line displays them.
generated_ids = model.generate(**batch)
output_ids = generated_ids.cpu()
output_ids

In [None]:
# Convert each generated id sequence to its token strings, dropping every
# token equal to the tokenizer's pad token.
output_tokens = []
for sequence_ids in output_ids:
    sequence_tokens = tokenizer.convert_ids_to_tokens(sequence_ids)
    output_tokens.append(
        [tok for tok in sequence_tokens if tok != tokenizer.pad_token]
    )
output_tokens