In [None]:
import os

import numpy as np
from huggingface_hub import login
from omegaconf import OmegaConf

from oumi.builders import (
    build_dataset_mixture,
    build_model,
    build_peft_model,
    build_tokenizer,
    build_trainer,
)
from oumi.core.types import DatasetSplit, TrainingConfig
from oumi.utils.saver import save_model

%load_ext autoreload
%autoreload 2

In [None]:
# Authenticate against the Hugging Face Hub with the token from the HF_TOKEN
# environment variable (never hardcode it in the notebook).
access_token = os.getenv("HF_TOKEN")
login(token=access_token)

In [None]:
# Load the training config: start from the structured TrainingConfig schema and
# overlay the values read from the YAML file.
config_filename = "../configs/oumi/zephyr.7b/sft/config_qlora.yaml"
base_config = OmegaConf.structured(TrainingConfig)
file_config = TrainingConfig.from_yaml(config_filename)
# Keep the merged OmegaConf object under its own name so that `config` is only
# ever bound to a TrainingConfig instance, matching its annotation.
merged_config = OmegaConf.merge(base_config, file_config)
config: TrainingConfig = OmegaConf.to_object(merged_config)
print(config.training)
print(config.peft)

In [None]:
# Debug override: cap training at a handful of steps for a quick smoke run.
# Set DEBUG_MAX_STEPS = None to keep the value from the YAML config.
DEBUG_MAX_STEPS = 2
if DEBUG_MAX_STEPS is not None:
    config.training.max_steps = DEBUG_MAX_STEPS

In [None]:
# Build the tokenizer from the model section of the config; the bare name on
# the last line displays its repr as the cell output.
tokenizer = build_tokenizer(config.model)
tokenizer

In [None]:
# Print the tokenizer settings that matter for SFT: max length, special
# tokens, padding side and the chat template.
_tokenizer_report = [
    ("tokenizer.model_max_length", tokenizer.model_max_length),
    ("tokenizer pad_token/eos_token", tokenizer.pad_token, tokenizer.eos_token),
    ("tokenizer.padding_side", tokenizer.padding_side),
    ("tokenizer.chat_template", tokenizer.chat_template),
]
for _report_row in _tokenizer_report:
    print(*_report_row)

In [None]:
# Load data & preprocessing: build the training split of the configured
# dataset mixture.
dataset = build_dataset_mixture(config, tokenizer, DatasetSplit.TRAIN)

# Debug hack: keep a reproducible random subsample of the training data.
# Only done for non-streaming datasets (the code below relies on len() and
# .select()).
SUBSAMPLE_SIZE = 1024
SUBSAMPLE_SEED = 1234

if not config.data.train.stream:
    print("dataset size before subsampling:", len(dataset))
    # A local RandomState draws exactly the same indices as the former
    # np.random.seed(SUBSAMPLE_SEED) + np.random.choice(...) without mutating
    # NumPy's global RNG. Cap the sample size so datasets smaller than
    # SUBSAMPLE_SIZE don't raise ValueError (replace=False).
    rng = np.random.RandomState(SUBSAMPLE_SEED)
    subsample_idx = rng.choice(
        len(dataset), min(SUBSAMPLE_SIZE, len(dataset)), replace=False
    )
    dataset = dataset.select(subsample_idx)
    print("dataset size after subsampling:", len(dataset))

dataset

In [None]:
# PEFT is used only when the training flag is set AND a PEFT section exists.
# Coerce to bool so `use_peft` holds True/False rather than the PEFT config
# object itself (which is what `a and b` would otherwise return).
use_peft = bool(config.training.use_peft and config.peft)
print("use_peft", use_peft)

# Build the base model; PEFT params are forwarded only when PEFT is enabled
# (presumably so QLoRA quantization settings apply at load time — confirm).
model = build_model(
    model_params=config.model, peft_params=config.peft if use_peft else None
)

if use_peft:
    model = build_peft_model(
        model, config.training.enable_gradient_checkpointing, config.peft
    )

# Enable gradients for input embeddings. With gradient checkpointing this is
# presumably needed so gradients can flow through a frozen base model.
if config.training.enable_gradient_checkpointing:
    model.enable_input_require_grads()

In [None]:
# Resolve the trainer factory for the configured trainer type once, then
# instantiate it with the model, tokenizer, HF-format training arguments and
# the training dataset.
create_trainer_fn = build_trainer(config.training.trainer_type)

trainer = create_trainer_fn(
    model=model,
    tokenizer=tokenizer,
    args=config.training.to_hf(),
    train_dataset=dataset,
    **config.training.trainer_kwargs,
)

In [None]:
# Run training with the trainer constructed in the previous cell.
trainer.train()

In [None]:
# Persist the trainer's state first, then write the final model through
# oumi's save_model helper using the same config.
trainer.save_state()
save_model(config=config, trainer=trainer)

In [None]:
## Warnings seen during training that still need follow-up:
# UserWarning: You passed a tokenizer with `padding_side` not equal to `right`
# to the SFTTrainer. This might lead to some unexpected behaviour due to overflow
# issues when training a model in half-precision. You might
# consider adding `tokenizer.padding_side = 'right'` to your code.

# TODO - update our code base if we use optimum (build_model)
# Using `disable_exllama` is deprecated and will be removed in version 4.37.
# Use `use_exllama` instead and specify the version with `exllama_config`.
# The value of `use_exllama` will be overwritten by `disable_exllama` passed
# in `GPTQConfig` or stored in your config file. # noqa
# WARNING:auto_gptq.nn_modules.qlinear.qlinear_cuda:CUDA extension not installed.
# # TODO update in main repo # noqa


# TODO Consider adding special tokens like '<|assistant|>', '<|system|>'
# via tokenizer.additional_special_tokens -- need to check Mistral

# from alignment team:
# tokenizer.encode("<|system|>")  # We already wrap <bos> and <eos>
# # in the chat template

# Future TODO.
# # For ChatML we need to add special tokens and resize the embedding layer
# if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: # noqa
#     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) # noqa
#     model, tokenizer = setup_chat_format(model, tokenizer)
#     model_kwargs = None


# if tokenizer.model_max_length > 100_000:  # TODO: should this check also be
#  applied to the different Zephyr models? Currently it is not.

# tokenizer.all_special_tokens
# print(tokenizer.encode("|system|"))
# dataset[0]["text"]