In [None]:
from trl import SFTTrainer
import transformers
from transformers import AutoModelForSeq2SeqLM, set_seed, AutoTokenizer, TrainingArguments
from accelerate import Accelerator
from typing import Dict
import torch
from peft import LoraConfig
from transformers.trainer_utils import get_last_checkpoint
import logging
from transformers import DataCollatorForSeq2Seq
import evaluate
import numpy as np


# Module-level logger for this notebook. Python's root logger defaults to WARNING
# with no handler, so without this setup the logger.info(...) progress messages
# emitted later in the notebook would be silently dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    # Guard so re-running this cell does not attach duplicate handlers.
    logger.addHandler(logging.StreamHandler())

In [None]:
### Dataset processing
from datasets import DatasetDict, concatenate_datasets, load_dataset

In [None]:
from huggingface_hub import login
import os

# Never hardcode access tokens in a notebook: they leak through version control
# and shared outputs. Read the token from the HF_TOKEN environment variable. If it
# is unset, login() receives None and falls back to prompting for a token.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Hugging Face Hub id (or local path) of the seq2seq checkpoint to fine-tune.
# Must be filled in before running the notebook.
MODEL_NAME = ""
if not MODEL_NAME:
    # Fail fast with an actionable message instead of whatever error
    # from_pretrained raises for an empty model id.
    raise ValueError("MODEL_NAME is empty; set it to a Hugging Face model id or local path.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# SacreBLEU scorer from the `evaluate` library; compute_metrics reads it as a global.
metric = evaluate.load("sacrebleu")
def compute_metrics(eval_preds):
    """Compute corpus-level SacreBLEU for generated predictions.

    Reads the module-level ``tokenizer`` and ``metric`` globals.

    Args:
        eval_preds: ``(predictions, labels)`` pair handed over by the Trainer.
            ``predictions`` may itself be a tuple whose first element holds the
            token ids. Both are assumed to be 2-D integer arrays of token ids,
            i.e. generated ids from evaluation with ``predict_with_generate``
            (TODO confirm for the trainer in use).

    Returns:
        dict: ``{"bleu": <SacreBLEU score>}``.
    """
    preds, labels = eval_preds
    # In case the model returns more than the prediction ids
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is the ignore index used to pad labels. The HF Trainer also uses it to
    # pad predictions of different lengths when concatenating eval batches.
    # Neither can be decoded, so map -100 to the pad token in BOTH arrays.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace; SacreBLEU expects one list of references per prediction.
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

In [None]:
def mix_datasets(dataset_mixer, shuffle=True, seed=42, test_percentage=0.1):
    """Load several Hub datasets and mix their train splits by proportion.

    Example format for ``dataset_mixer``::

        dataset_mixer = {
            "dataset1": 1,    # dataset_name: proportion of its train split to keep
            "dataset2": 0.3,
        }

    For each dataset, the first ``int(len(train) * proportion)`` rows of its
    ``"train"`` split are kept; no shuffling happens before this cut. All
    available ``"test"`` splits are concatenated and always shuffled with ``seed``.

    Args:
        dataset_mixer: Mapping of dataset name (passed to ``load_dataset``) to the
            fraction of its train split to use, in ``[0, 1]``.
        shuffle: Whether to shuffle the mixed train split.
        seed: Seed for every shuffle and for the fallback train/test split.
        test_percentage: Test fraction carved out of the mixed train split when
            none of the datasets has a ``"test"`` split.

    Returns:
        DatasetDict with ``"train"`` and ``"test"`` splits.

    Raises:
        ValueError: If a proportion is outside ``[0, 1]`` or no dataset has a
            ``"train"`` split.
    """
    raw_train_datasets = []
    raw_test_datasets = []
    for name, frac in dataset_mixer.items():
        # Validate before downloading anything; a fraction > 1 would otherwise
        # surface later as an out-of-range index error from select().
        if not 0 <= frac <= 1:
            raise ValueError(f"Proportion for {name!r} must be in [0, 1], got {frac}")
        dataset = load_dataset(name)
        if "train" in dataset:
            raw_train_datasets.append((dataset["train"], frac))
        if "test" in dataset:
            raw_test_datasets.append(dataset["test"])

    if not raw_train_datasets:
        raise ValueError("None of the datasets in dataset_mixer has a 'train' split")

    train_subsets = [ds.select(range(int(len(ds) * frac))) for ds, frac in raw_train_datasets]
    train_dataset = concatenate_datasets(train_subsets)
    if shuffle:
        train_dataset = train_dataset.shuffle(seed=seed)

    if raw_test_datasets:
        new_dataset = DatasetDict()
        new_dataset["train"] = train_dataset
        new_dataset["test"] = concatenate_datasets(raw_test_datasets).shuffle(seed=seed)
        return new_dataset

    # No test split available: carve one out of the mixed train data. The seed is
    # passed so the split is reproducible across runs.
    return train_dataset.train_test_split(test_size=test_percentage, seed=seed)


In [None]:
# Maximum token length for both the tokenized inputs and the tokenized targets.
max_length = 128

def preprocess_function(examples, input_column="question", target_column="queries"):
    """Tokenize a batch of source/target text pairs for seq2seq training.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` maps column names
    to lists of strings. Reads the module-level ``tokenizer`` and ``max_length``.

    Args:
        examples: Batch dict containing ``input_column`` and ``target_column``.
        input_column: Column holding the source text (default ``"question"``).
        target_column: Column holding the target text (default ``"queries"``).

    Returns:
        The tokenizer's output for the inputs, with the targets tokenized via
        ``text_target`` (transformers stores these under ``labels``). Both are
        truncated to ``max_length``.
    """
    model_inputs = tokenizer(
        list(examples[input_column]),
        text_target=list(examples[target_column]),
        max_length=max_length,
        truncation=True,
    )
    return model_inputs

In [None]:
###Example###
# Mix proportions: dataset_name -> fraction of its train split to keep.
dataset_mixer = {
    "HuggingFaceH4/no_robots": 1,
}
# NOTE(review): preprocess_function reads "question"/"queries" columns; confirm
# every dataset in the mixer actually provides them (no_robots appears to use a
# chat-style schema), otherwise the map() below fails with a KeyError.

# Reuses the `tokenizer` loaded from MODEL_NAME in the earlier cell instead of
# loading it a second time.
dataset = mix_datasets(dataset_mixer)

# Tokenize and drop every raw column so only the tokenizer outputs remain.
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

num_raw_train_samples = len(dataset["train"])


In [None]:
# Expose the tokenized splits under the names the trainer setup expects.
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["test"]

In [None]:
##model kwargs

# Loading options for the model. Only `use_cache` is applied to the load below;
# the comment on that call explains why the others are withheld.
model_kwargs = dict(
    trust_remote_code=True,
    use_flash_attention_2=False,
    torch_dtype="auto",
    use_cache=False,
)
# Disable the k/v cache for training: gradient checkpointing (enabled in the
# training arguments) is incompatible with it, and a later cell re-enables it
# before saving the final config. The remaining model_kwargs are deliberately NOT
# forwarded. trust_remote_code=True would execute arbitrary code from the Hub repo.
# torch_dtype="auto" may load half-precision weights, which fp16 mixed-precision
# training cannot unscale. TODO: confirm before enabling them.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME, use_cache=model_kwargs["use_cache"]
)

# Dynamically pads inputs and labels per batch for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Local directory for checkpoints/metrics; also used as the Hub repo name (hub_model_id).
output_dir = "output_dir"

In [None]:
# Seq2SeqTrainingArguments (rather than plain TrainingArguments) is required for
# `predict_with_generate` / `generation_max_length`. With them, evaluation returns
# generated token ids, which compute_metrics decodes to score BLEU.
training_args = transformers.Seq2SeqTrainingArguments(
    output_dir=output_dir,
    # NOTE(review): this also drops the last incomplete *eval* batch, so a few
    # eval examples may be skipped. Confirm that is intended.
    dataloader_drop_last=True,
    # Newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="epoch",
    save_steps=100,
    # Keep at most 3 checkpoints on disk (a keyword may only be passed once).
    save_total_limit=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    num_train_epochs=5,
    fp16=True,
    bf16=False,
    report_to="none",
    ddp_find_unused_parameters=False,
    push_to_hub=True,
    predict_with_generate=True,
    # Cap generation in the per-epoch evaluations at the tokenization length.
    # Otherwise the model's default generation length applies.
    generation_max_length=max_length,
    weight_decay=0.01,
    hub_model_id="thangvip/" + output_dir,
)

In [None]:
# Seq2SeqTrainer rather than trl's SFTTrainer, for three reasons:
# - dataset.map() removed every raw column, so the data is already tokenized and
#   has no "text" field to pack;
# - compute_metrics decodes generated token ids;
# - the evaluate() calls below pass `max_length`, which only the generation-aware
#   Seq2SeqTrainer.evaluate accepts.
# Packing / constant-length-dataset options therefore do not apply, and the model
# is already instantiated, so no model-init kwargs are passed.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Newer transformers releases call this argument `processing_class`.
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline evaluation before fine-tuning. `max_length` is forwarded as a generation
# kwarg, which requires a generation-aware trainer such as Seq2SeqTrainer.
trainer.evaluate(max_length=max_length)

In [None]:
# Fine-tune, then log and persist the training metrics and the trainer state.
train_result = trainer.train()
train_metrics = train_result.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

In [None]:
# Evaluate again after fine-tuning, with the same generation length as the baseline run.
trainer.evaluate(max_length=max_length)

In [None]:
# Model-card / Hub metadata shared by create_model_card and push_to_hub.
kwargs = dict(
    finetuned_from=MODEL_NAME,
    dataset=list(dataset_mixer),
    dataset_tags=list(dataset_mixer),
)

# Only the main process writes the model card and the final config.
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    # Re-enable the k/v cache (disabled for training) so the saved config is
    # set up for fast inference.
    trainer.model.config.use_cache = True
    trainer.model.config.save_pretrained(output_dir)

# Optional final evaluation on the held-out split (off by default).
do_eval = False
if do_eval:
    logger.info("*** Evaluate ***")
    metrics = {**trainer.evaluate(), "eval_samples": len(eval_dataset)}
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

# Final upload of the trained model and model card to the Hub.
push_to_hub = True
if push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)
logger.info("*** Training complete ***")
