In [None]:
import os

from datasets import load_dataset
from huggingface_hub import login
from omegaconf import OmegaConf

from lema.builders import (
    build_model,
    build_peft_model,
    build_tokenizer,
    build_trainer,
)
from lema.core.types import TrainingConfig
from lema.datasets.ultrachat_200k import apply_chat_template
from lema.utils.torch_utils import device_cleanup, limit_per_process_memory

In [None]:
# Authenticate against the Hugging Face Hub using the token stored in the
# HF_TOKEN environment variable. When the variable is unset, `access_token`
# is None and `login` is called with token=None.
access_token = os.getenv("HF_TOKEN")
login(token=access_token)

In [None]:
# Load the SFT training config: start from the structured TrainingConfig schema
# (defaults + type checking via OmegaConf), overlay the values from the YAML
# file, then convert the merged result back into a plain TrainingConfig object.
config_filename = "../configs/lema/zephyr.7b.sft.yaml"
config: TrainingConfig = OmegaConf.to_object(
    OmegaConf.merge(
        OmegaConf.structured(TrainingConfig),
        TrainingConfig.from_yaml(config_filename),
    )
)
print(config.training)

In [None]:
# Prepare the accelerator before the tokenizer and model are built. Both helpers
# come from lema.utils.torch_utils. NOTE(review): judging by their names, the
# first presumably caps this process's device-memory share and the second frees
# cached device memory -- confirm in lema.utils.torch_utils.
limit_per_process_memory()
device_cleanup()

In [None]:
# Build the tokenizer described by the loaded training config, then display it
# (last bare expression of the cell).
tokenizer = build_tokenizer(config)
tokenizer

In [None]:
# Tokenizers without a configured maximum length report a very large
# placeholder value; treat anything above 100k tokens as "unset" and fall back
# to a 2048-token context.
if tokenizer.model_max_length > 100_000:
    tokenizer.model_max_length = 2048

# Summarise the tokenizer settings that matter for batching/padding during SFT.
print(f"tokenizer.model_max_length {tokenizer.model_max_length}")
print(f"tokenizer pad_token/eos_token {tokenizer.pad_token} {tokenizer.eos_token}")
print(f"tokenizer.padding_side {tokenizer.padding_side}")

In [None]:
# Zephyr-style chat template (Jinja). For each message:
#   user      -> "<|user|>\n"      + content + eos_token
#   system    -> "<|system|>\n"    + content + eos_token
#   assistant -> "<|assistant|>\n" + content + eos_token
# When `add_generation_prompt` is set, "<|assistant|>" is emitted after the last
# message so the model continues with an assistant turn.
# The pieces below concatenate to exactly the single-line template this cell
# has always used; they are split only for readability.
chat_template = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)
# The template is a non-empty literal, so install it unconditionally.
tokenizer.chat_template = chat_template

In [None]:
# Build the base model described by the training config.
model = build_model(config)

In [None]:
# Inspect how the tokenizer handles the Zephyr role markers ("<|system|>",
# "<|assistant|>", ...) used by the chat template. A notebook only renders the
# last bare expression of a cell, so every value is printed explicitly rather
# than being evaluated and silently discarded.
print("additional_special_tokens:", tokenizer.additional_special_tokens)

# Encoding with the tokenizer's default special-token handling.
print("encode('<|system|>'):", tokenizer.encode("<|system|>"))

# Same marker without special tokens. The chat template appends eos_token to
# each message itself and never references bos_token, so compare both forms.
print(
    "encode('<|system|>', add_special_tokens=False):",
    tokenizer.encode("<|system|>", add_special_tokens=False),
)

# The bare "|system|" text, to see how the marker splits without the angle brackets.
print("encode('|system|'):", tokenizer.encode("|system|"))

In [None]:
# Show which optional training features the config turns on
# (use_peft, then enable_gradient_checkpointing), one per line.
for flag in (
    config.training.use_peft,
    config.training.enable_gradient_checkpointing,
):
    print(flag)

In [None]:
# Optionally wrap the base model with PEFT adapters, as requested by the config.
model = build_peft_model(model, config) if config.training.use_peft else model

# With gradient checkpointing on, make the input-embedding outputs require grad
# so gradients can flow through checkpointed blocks even when the base weights
# are frozen (as with PEFT). NOTE(review): this cell does not itself enable
# checkpointing on the model; presumably the HF training args produced by
# `config.training.to_hf()` do -- confirm.
if config.training.enable_gradient_checkpointing:
    model.enable_input_require_grads()

In [None]:
# Load the configured split of the training dataset and report how many
# records it contains.
dataset = load_dataset(
    config.data.dataset_name,
    split=config.data.split,
)
print(len(dataset))

In [None]:
# preprocessing_fn = build_prompt_generation_fn(preprocessing_function_name, tokenizer)
# dataset = dataset.map(preprocessing_fn, batched=True, **kwargs)
# dataset = dataset.map(preprocessing_fn, batched=True)

# # For ChatML we need to add special tokens and resize the embedding layer
# if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: # noqa
#     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) # noqa
#     model, tokenizer = setup_chat_format(model, tokenizer)
#     model_kwargs = None

In [None]:
# Render every conversation with the tokenizer's chat template in SFT mode.
# `auto_insert_empty_system_msg` inserts an empty system message as the first
# message when the chat template mentions a `system` role.
chat_template_kwargs = {
    "tokenizer": tokenizer,
    "task": "sft",
    "auto_insert_empty_system_msg": True,
}

# NOTE(review): the worker count is hard-coded to 6 processes; consider deriving
# it from the config or the machine's CPU count. `remove_columns=[]` removes no
# columns, so all original columns are kept.
dataset = dataset.map(
    apply_chat_template,
    fn_kwargs=chat_template_kwargs,
    num_proc=6,
    remove_columns=[],
    desc="Applying chat template",
)

In [None]:
# dataset[0]

In [None]:
# Resolve the trainer class to use for this config (it is instantiated in the
# next cell), then display it.
trainer_cls = build_trainer(config)
trainer_cls

In [None]:
# Convert the training section of the config into the trainer's argument object
# (presumably transformers TrainingArguments -- see `to_hf`).
training_args = config.training.to_hf()

# Instantiate the trainer. Extra, dataset-specific keyword arguments from the
# config are forwarded as-is; a key that duplicates one of the explicit
# arguments below raises TypeError.
trainer = trainer_cls(
    args=training_args,
    model=model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    **config.data.trainer_kwargs,
)

In [None]:
# Display the tokenizer's `max_len_single_sentence`. NOTE(review): presumably
# model_max_length minus the special tokens added to a single sequence --
# confirm against the transformers tokenizer docs.
tokenizer.max_len_single_sentence

In [None]:
# Run training and display the value returned by `trainer.train()`.
# Note: nothing is persisted afterwards -- the following cell that saves the
# final checkpoint and training state is commented out; re-enable it to keep
# the results.
trainer.train()

In [None]:
# # Save final checkpoint & training state
# trainer.save_state()

# save_model(
#     config=config,
#     trainer=trainer,
# )