In [None]:
import copy
import logging
import math
import time
from pathlib import Path
from typing import Any

import datasets
from torch.utils.data import DataLoader
from transformers import Trainer, TrainingArguments

from foresight.datasets.data_collator_v2 import (
    DataCollatorForLanguageModelingMaskStaticVariables,
)
from foresight.models.foresight_llama import (
    ForesightLlamaConfig,
    ForesightLlamaForCausalLM,
)
from foresight.tokenizers import PreTrainedTokenizerFastWithPositionIDPadding

In [None]:
# Configure the root logger: drop any handlers already attached (e.g. from
# re-running this cell) and emit INFO-and-above records through a single
# default StreamHandler.
log = logging.getLogger()
del log.handlers[:]
log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler())

In [None]:
# Every artefact lives under ./outputs, relative to the kernel's working dir.
OUTPUT_DIR = Path.cwd().joinpath("outputs")
SAVE_TOKENIZER_PATH = OUTPUT_DIR.joinpath("tokenizer")
SAVE_ENCODED_DATASET_PATH = OUTPUT_DIR.joinpath("encoded_dataset")

# A fresh directory per run, named by local timestamp (to the second).
_RUN_STAMP = time.strftime("%Y_%m_%d_%H_%M_%S")
MODEL_LOGS_DIR = OUTPUT_DIR.joinpath("model_logs", _RUN_STAMP)
FINAL_MODEL_DIR = MODEL_LOGS_DIR.joinpath("final_model")
MODEL_LOGS_DIR.mkdir(parents=True, exist_ok=True)

# Passed to the data collator as `num_static_variables`
# (presumably the number of leading static-variable tokens to mask — confirm).
NUM_STATIC_VARIABLES = 4

In [None]:
encoded_dataset = datasets.load_from_disk(str(SAVE_ENCODED_DATASET_PATH))
# Later cells index the "train" and "test" splits directly; fail fast here with
# a readable message if the saved dataset does not have that layout.
assert isinstance(encoded_dataset, datasets.DatasetDict), type(encoded_dataset)
assert {"train", "test"} <= set(encoded_dataset), list(encoded_dataset)
encoded_dataset

In [None]:
# Reload the tokenizer saved by the preprocessing step, then build the collator
# that batches encoded samples for causal-LM training (mlm=False).
tokenizer = PreTrainedTokenizerFastWithPositionIDPadding.from_pretrained(SAVE_TOKENIZER_PATH)

training_data_collator = DataCollatorForLanguageModelingMaskStaticVariables(
    tokenizer=tokenizer,
    mlm=False,
    num_static_variables=NUM_STATIC_VARIABLES,
)

# Create Model

In [None]:
def get_model(
    params: dict[str, Any],
    tokenizer: PreTrainedTokenizerFastWithPositionIDPadding,
    max_sequence_length: int,
):
    print("get_model", params)
    if params is None:
        params = {}

    hidden_size = params.get("hidden_size", 512)
    # From OLMo paper
    intermediate_size = hidden_size / (8 / 3)
    intermediate_size = round(intermediate_size / 100) * 100

    config = ForesightLlamaConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=params.get("num_attention_heads", 4),
        num_attention_heads=params.get(
            "num_attention_heads", 4
        ),  # TODO: Check if to tie these
        num_key_value_heads=params.get(
            "num_attention_heads", 4
        ),  # TODO: Use multi-head attention
        max_position_embeddings=max_sequence_length,
        use_cache=False,  # TODO: Figure out how to use cache
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=-100,  # We don't use BOS token
        sep_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
        rope_theta=10000.0,  # TODO: Read up on ROPE
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=params.get("attention_dropout", 0.0),  # Config
    )

    return ForesightLlamaForCausalLM(config)


# Head-room factor applied on top of the longest encoded sequence.
SEQUENCE_LENGTH_HEADROOM = 1.2

# Size the positional range from the longest sequence across *all* splits (not
# only train), so evaluation batches also fit in max_position_embeddings.
# Reading just the "input_ids" column avoids decoding every other column.
max_sequence_length = math.ceil(
    max(
        len(input_ids)
        for split in encoded_dataset.values()
        for input_ids in split["input_ids"]
    )
    * SEQUENCE_LENGTH_HEADROOM
)


def get_model_lambda(params=None):
    """``Trainer`` ``model_init`` hook: build a fresh model from ``params``.

    Kept as a plain one-positional-argument function (rather than
    ``functools.partial``) because ``Trainer`` accepts a ``model_init`` taking
    zero arguments or a single trial argument.
    """
    return get_model(params, tokenizer, max_sequence_length)


trial_model = get_model_lambda(None)

In [None]:
# Trainable parameter count of the default-config model built above.
# (Leftover scratch-arithmetic cells that used to sit here were removed.)
sum(p.numel() for p in trial_model.parameters() if p.requires_grad)

In [None]:
# Smoke test: collate one unshuffled training batch and run a forward pass to
# check the logits shape before handing everything to the Trainer.
trial_dataloader = DataLoader(
    encoded_dataset["train"],
    batch_size=8,
    shuffle=False,
    collate_fn=training_data_collator,
)
batch = next(iter(trial_dataloader))
trial_model(**batch).logits.shape

# Trainer

In [None]:
# Set to 0 to force CPU training.
gpus_per_trial = 1
training_args = TrainingArguments(
    output_dir=str(MODEL_LOGS_DIR),  # checkpoints go in the timestamped run dir
    no_cuda=gpus_per_trial <= 0,
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` and `no_cuda` to `use_cpu`; update when upgrading.
    evaluation_strategy="epoch",
    # Trainer requires save and eval strategies to match when
    # load_best_model_at_end=True.
    save_strategy="epoch",
    load_best_model_at_end=True,
    num_train_epochs=5,
    per_device_eval_batch_size=32,
    per_device_train_batch_size=32,  # config
    warmup_ratio=0.1,  # config
    weight_decay=0.1,  # config
    # Keep logs inside this run's directory instead of a shared ./logs in the
    # current working directory, so separate runs don't mix.
    logging_dir=str(MODEL_LOGS_DIR / "logs"),
    skip_memory_metrics=True,
    report_to="none",
    disable_tqdm=True,
)

In [None]:
def compute_objective(metrics, metric_name="eval_loss"):
    """Return the scalar to optimise from a ``Trainer`` metrics mapping.

    Args:
        metrics: Mapping of metric names to values (e.g. the dict produced by
            an evaluation run). It is not modified.
        metric_name: Key to read; defaults to the evaluation loss.

    Returns:
        ``metrics[metric_name]``.

    Raises:
        KeyError: If ``metric_name`` is not present in ``metrics``.
    """
    # A plain lookup never mutates the caller's dict, so the previous
    # deepcopy-then-pop (copying every metric to read one) is unnecessary.
    return metrics[metric_name]

In [None]:
# The Trainer receives `model_init` rather than a ready-built model, so it
# constructs the model itself via get_model_lambda.
trainer = Trainer(
    args=training_args,
    model_init=get_model_lambda,
    data_collator=training_data_collator,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
)

In [None]:
# `Trainer.train()` returns a TrainOutput (global step, training loss, metrics),
# not the model; the trained weights live on `trainer.model`.
train_output = trainer.train()
# Persist the final weights (with load_best_model_at_end=True this is the best
# checkpoint, by default judged on eval loss) to the run's final_model dir.
trainer.save_model(str(FINAL_MODEL_DIR))
train_output