In [None]:
from trl import SFTTrainer, setup_chat_format
import transformers
from transformers import AutoModelForCausalLM, set_seed, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from accelerate import Accelerator
from typing import Dict
import torch
from peft import LoraConfig
from transformers.trainer_utils import get_last_checkpoint
import logging

# Attach a default stderr handler and enable INFO on this notebook's logger so
# that the progress messages emitted with `logger.info(...)` further down are
# actually shown; with Python's default WARNING threshold they are dropped.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [None]:
### Dataset processing
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk

In [None]:
import os

from huggingface_hub import login

# Never paste an access token into the notebook: it ends up in version control
# and in shared copies. Read it from the HF_TOKEN environment variable instead;
# when the variable is unset or empty, `token=None` lets `login` fall back to
# its interactive prompt.
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# Columns that survive `mix_datasets`; every other column is removed there.
# "system_prompt" is listed because `apply_template` reads it for
# question/answer datasets -- if it were stripped, that branch could never run.
COLUMNS_TO_KEEP = [
    "messages",
    "chosen",
    "rejected",
    "prompt",
    "completion",
    "label",
    "question",
    "answer",
    "system_prompt",
]

## Dataset processing helpers

In [None]:
def mix_datasets(dataset_mixer, shuffle=True, seed=42, test_percentage=0.1):
    '''
    Load several datasets and combine a fraction of each into one `DatasetDict`.

    Example format for dataset_mixer:
    dataset_mixer = {
            "dataset1": 1, # dataset_name: proportion
            # "dataset2": 0.3,
            # "dataset3": 0.2,
                }

    Args:
        dataset_mixer: mapping of dataset name (passed to `load_dataset`) to the
            fraction of that dataset's "train" split to keep, between 0 and 1.
            The first `int(len * fraction)` rows are taken.
        shuffle: whether to shuffle the concatenated training set.
        seed: seed used for every shuffle and for the fallback train/test split,
            so that repeated runs produce the same splits.
        test_percentage: share of the training data held out as "test" when none
            of the datasets ships its own "test" split.

    Returns:
        A `DatasetDict` with "train" and "test" entries. Only columns listed in
        `COLUMNS_TO_KEEP` are retained.

    Raises:
        ValueError: if a fraction is outside [0, 1] or no dataset has a "train" split.
    '''
    train_subsets = []
    raw_test_datasets = []
    for name, frac in dataset_mixer.items():
        # Fail early with a readable message instead of an index error from `select`.
        if not 0 <= frac <= 1:
            raise ValueError(f"Dataset fraction for `{name}` must be between 0 and 1, got {frac}")

        dataset = load_dataset(name)
        if "train" in dataset:
            train_split = dataset["train"]
            train_split = train_split.remove_columns(
                [col for col in train_split.column_names if col not in COLUMNS_TO_KEEP]
            )
            train_subsets.append(train_split.select(range(int(len(train_split) * frac))))

        if "test" in dataset:
            test_split = dataset["test"]
            test_split = test_split.remove_columns(
                [col for col in test_split.column_names if col not in COLUMNS_TO_KEEP]
            )
            raw_test_datasets.append(test_split)

    if not train_subsets:
        raise ValueError(
            f"None of the datasets in {list(dataset_mixer.keys())} provides a `train` split to mix"
        )

    train_dataset = concatenate_datasets(train_subsets)
    if shuffle:
        train_dataset = train_dataset.shuffle(seed=seed)

    if raw_test_datasets:
        new_dataset = DatasetDict()
        new_dataset["train"] = train_dataset
        new_dataset["test"] = concatenate_datasets(raw_test_datasets).shuffle(seed=seed)
        return new_dataset

    # No dataset came with a test split: carve one out of the training data.
    # Passing the seed keeps this split reproducible across runs.
    return train_dataset.train_test_split(test_size=test_percentage, seed=seed)


In [None]:
def _with_system_message(messages):
    """Return `messages` with an empty system turn prepended unless it already starts with one."""
    if messages and messages[0]["role"] == "system":
        return messages
    return [{"role": "system", "content": ""}] + list(messages)


def apply_template(example, tokenizer, task="sft"):
    '''
    Render one dataset example into text with the tokenizer's chat template.

    task can be: sft, rm, dpo, generation

    Args:
        example: one dataset row (dict-like). Expected keys depend on `task`:
            - sft / generation: `messages`, or `question` + `answer`
              (with an optional `system_prompt`).
            - rm: `chosen` and `rejected`, each a list of chat messages.
            - dpo: `chosen` and `rejected`, either lists of chat messages whose
              last entry is the completion, or plain strings accompanied by `prompt`.
        tokenizer: object exposing `apply_chat_template`.
        task: which of the formats above to produce.

    Returns:
        The same `example`, with `text` (sft / generation) or `text_chosen`,
        `text_rejected` (rm, dpo) and `text_prompt` (dpo) added.

    Raises:
        ValueError: if `task` is unknown or the example lacks the keys the task needs.
    '''
    if task in ("sft", "generation"):
        if "messages" in example:
            messages = _with_system_message(example["messages"])
        elif "question" in example and "answer" in example:
            system_prompt = example["system_prompt"] if "system_prompt" in example else ""
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": example["question"]},
                {"role": "assistant", "content": example["answer"]},
            ]
        else:
            # Without this branch `messages` would be unbound and the failure
            # would surface as an unrelated NameError further down.
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require `messages` or `[question, answer]` keys but found {list(example.keys())}"
            )
        example["text"] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(task == "generation")
        )

    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            # We add an empty system message if there is none
            chosen_messages = _with_system_message(example["chosen"])
            rejected_messages = _with_system_message(example["rejected"])

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )

    elif task == "dpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            if isinstance(example["chosen"], list):
                # Everything before the final turn is the shared prompt; it may be
                # empty when `chosen` holds a single message.
                prompt_message = _with_system_message(example["chosen"][:-1])
                chosen_message = example["chosen"][-1:]
                rejected_message = example["rejected"][-1:]
            else:
                if "prompt" not in example:
                    raise ValueError(
                        f"Could not format example as dialogue for `dpo` task! Require a `prompt` key when `chosen` is not a message list but found {list(example.keys())}"
                    )
                # `apply_chat_template` works on lists of role/content dicts, so
                # plain strings are wrapped into single-turn conversations first.
                prompt = example["prompt"]
                if isinstance(prompt, str):
                    prompt = [{"role": "user", "content": prompt}]
                prompt_message = _with_system_message(prompt)
                chosen_message = [{"role": "assistant", "content": example["chosen"]}]
                rejected_message = [{"role": "assistant", "content": example["rejected"]}]

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_message, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_message, tokenize=False)
            example["text_prompt"] = tokenizer.apply_chat_template(prompt_message, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(f"Task {task} not recognized")

    return example

In [None]:
# Name or path of the base model. It is passed to `AutoTokenizer.from_pretrained`,
# used as the model to fine-tune, and recorded as `finetuned_from` in the model card.
# NOTE(review): the empty string looks like a placeholder -- set it before running
# the cells below.
MODEL_NAME = ""

In [None]:
### Example ###
# Maps dataset name -> proportion of its train split to use. Add further
# entries (e.g. "other/dataset": 0.3) to mix several datasets.
dataset_mixer = {"HuggingFaceH4/no_robots": 1}

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.model_max_length = 2048

dataset = mix_datasets(dataset_mixer)

# Render every example to text with the chat template; the original columns
# are dropped so that only the columns added by `apply_template` remain.
template_kwargs = {"tokenizer": tokenizer, "task": "sft"}
original_columns = list(dataset["train"].features)
dataset = dataset.map(
    apply_template,
    fn_kwargs=template_kwargs,
    num_proc=4,
    remove_columns=original_columns,
    desc="Applying chat template",
)

num_raw_train_samples = len(dataset["train"])


In [None]:
# Pull the two splits out of the DatasetDict for the trainer.
train_dataset, eval_dataset = dataset["train"], dataset["test"]

## Model setup

In [None]:
load_in_4bit = True
load_in_8bit = False

# bitsandbytes quantization settings; stays None when neither flag is set, in
# which case the model is loaded unquantized.
quantization_config = None
if load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        # Must be a real torch dtype: a string given here is looked up as an
        # attribute of `torch`, and "auto" is not one. float16 matches the
        # fp16=True setting used in the training arguments.
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )
# Checked second, so 8-bit wins if both flags are switched on.
if load_in_8bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )

In [None]:
##detect device map
def get_current_device() -> int:
    """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
    return {"": get_current_device()} if torch.cuda.is_available() else None

In [None]:
use_peft = True

if use_peft:
    # LoRA adapters on the attention projections.
    # `modules_to_save` is deliberately not set: it is meant for modules *other
    # than* the adapter targets that should be trained in full and stored with
    # the adapter. Listing the same projections there as in `target_modules`
    # would turn them into fully trainable copies, defeating the purpose of LoRA.
    peft_config = LoraConfig(
        r=64,
        lora_alpha=128,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "o_proj", "v_proj"],
    )
else:
    # Full fine-tuning: the trainer receives no adapter configuration.
    peft_config = None

In [None]:
## Model loading kwargs

# A device map is only pinned when the model is quantized; otherwise placement
# is left to the loader / trainer.
if quantization_config is None:
    kbit_device_map = None
else:
    kbit_device_map = get_kbit_device_map()

# Keyword arguments used when the model is instantiated from its name.
model_kwargs = {
    "trust_remote_code": True,
    "use_flash_attention_2": False,
    "torch_dtype": "auto",
    "use_cache": False,
    "device_map": kbit_device_map,
    "quantization_config": quantization_config,
}

# Start from the model name; a later cell may replace it with a loaded model.
model = MODEL_NAME

In [None]:
# Templates in ChatML style use the <|im_start|> marker; for those the model is
# loaded here and passed through `setup_chat_format` together with the tokenizer.
# - `chat_template` can be None for tokenizers that ship without one, so it is
#   normalised to "" before the substring test.
# - The `isinstance(model, str)` guard makes the cell safe to re-run: once
#   `model` holds a loaded model there is nothing left to do.
chat_template = tokenizer.chat_template or ""
if isinstance(model, str) and "<|im_start|>" in chat_template:
    model = AutoModelForCausalLM.from_pretrained(model, **model_kwargs)
    model, tokenizer = setup_chat_format(model, tokenizer)
    # The model is already instantiated, so the trainer must not be handed
    # init kwargs for it.
    model_kwargs = None


In [None]:
# Directory handed to `TrainingArguments(output_dir=...)` and used when saving
# the final model and config; it is also appended to the Hub user name to form
# `hub_model_id`.
output_dir = "output_dir"

In [None]:
# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # Where results go
    output_dir=output_dir,
    push_to_hub=True,
    hub_model_id=f"thangvip/{output_dir}",
    report_to="none",
    # Checkpointing and evaluation cadence
    evaluation_strategy="steps",
    save_steps=10,
    save_total_limit=5,
    # Batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    dataloader_drop_last=True,
    # Optimisation
    learning_rate=1e-4,
    num_train_epochs=1,
    # Precision and memory
    fp16=True,
    bf16=False,
    gradient_checkpointing=True,
    # Distributed training
    ddp_find_unused_parameters=False,
)

In [None]:
# Supervised fine-tuning trainer over the templated "text" column, packing
# examples into sequences of up to `max_seq_length` tokens.
# `seq_length`, `eos_token_id` and `infinite` are intentionally not passed:
# the first two are not SFTTrainer constructor arguments (the sequence length is
# given by `max_seq_length`, and the tokenizer is passed in directly), and
# `infinite` is deprecated in TRL in favour of `num_train_epochs` / `max_steps`
# on the training arguments.
trainer = SFTTrainer(
    model=model,
    # None when `model` is already an instantiated model rather than a name.
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    packing=True,
    peft_config=peft_config,
)

In [None]:
import os

resume_from_last_checkpoint = False
checkpoint_to_resume = None

# `get_last_checkpoint` needs the folder to scan and lists its contents, so it
# is only called when the output directory already exists.
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None

# Precedence: newest checkpoint in `output_dir` (when requested and present),
# then an explicitly named checkpoint, otherwise train from scratch.
checkpoint = None
if resume_from_last_checkpoint and last_checkpoint is not None:
    checkpoint = last_checkpoint
elif checkpoint_to_resume is not None:
    checkpoint = checkpoint_to_resume

train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

logger.info("*** Save model ***")
trainer.save_model(output_dir)
logger.info(f"Model saved to {output_dir}")

In [None]:
# Switches for the optional post-training steps.
do_eval = True
push_to_hub = True

# Metadata shared by the model card and the Hub upload.
kwargs = dict(
    finetuned_from=MODEL_NAME,
    dataset=list(dataset_mixer),
    dataset_tags=list(dataset_mixer),
)

# Only the main process writes the model card and the config.
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    model_config = trainer.model.config
    # Restore k,v cache for fast inference
    model_config.use_cache = True
    model_config.save_pretrained(output_dir)

if do_eval:
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(eval_dataset)
    for record_metrics in (trainer.log_metrics, trainer.save_metrics):
        record_metrics("eval", metrics)

if push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)

logger.info("*** Training complete ***")
