# Mistral-7B function-calling trainer

## Install Dependencies

In [None]:
# Install the CPU QLoRA stack. The exact pins below (transformers / peft /
# accelerate) are presumably the versions the bigdl.llm.transformers.qlora
# wrappers were validated against — confirm before bumping any of them.
# NOTE(review): in shells that glob square brackets (e.g. zsh) the extras
# spec may need quoting: "bigdl-llm[all]".
# bitsandbytes is presumably required because the notebook builds a
# transformers.BitsAndBytesConfig further down — verify.
!pip install --pre --upgrade bigdl-llm[all]
!pip install transformers==4.36.1
!pip install peft==0.5.0
!pip install datasets
!pip install accelerate==0.23.0
!pip install bitsandbytes scipy

## Import Libraries

In [None]:
import os
os.environ["HF_HOME"]="./data/cache"
import json
import torch
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.low_bit_linear import LowBitLinear
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, LoraConfig
from bigdl.llm.utils.isa_checker import ISAChecker
from datasets import load_dataset

## Initialise Tokenizer & Model

In [None]:
# Hugging Face Hub id of the base checkpoint that gets fine-tuned.
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# Tokenizer: prepend BOS to every encoded prompt and pad with id 0 on the left.
# NOTE(review): id 0 is presumably Mistral's <unk> token, reused as padding — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_bos_token=True,
    trust_remote_code=True,
)
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"

# 4-bit weight quantisation settings handed to BigDL's AutoModelForCausalLM.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="int4",  # nf4 not supported on cpu yet
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
)

In [None]:
# Move the quantised model to CPU and run BigDL's k-bit training preparation
# on it (gradient checkpointing disabled, matching the TrainingArguments below).
model = prepare_model_for_kbit_training(model.to('cpu'), use_gradient_checkpointing=False)
# Have the input-embedding output require grad so gradients can reach the
# LoRA adapters even though the base weights are not trained.
model.enable_input_require_grads()

In [None]:
def get_all_linear_layers(model, linear_cls=None, exclude=("lm_head",)):
    """Return the leaf names of the model's linear layers, for LoRA targeting.

    Args:
        model: module whose ``named_modules()`` are scanned.
        linear_cls: layer type to match; ``None`` (the default) means BigDL's
            ``LowBitLinear``, i.e. the quantised linear layers.
        exclude: leaf names to leave out. ``lm_head`` is skipped by default,
            following the usual QLoRA practice of not adapting the output
            projection; pass ``()`` to include every match.

    Returns:
        A sorted list of unique leaf names. Sorting makes the list (and the
        adapter config built from it) identical on every run; iterating a set
        of strings would not be, because str hashing is randomised per process.
    """
    if linear_cls is None:
        linear_cls = LowBitLinear
    leaf_names = {
        # peft matches target_modules against the end of each module's dotted
        # name, so only the last component is needed.
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }
    return sorted(leaf_names.difference(exclude))

# LoRA adapter hyper-parameters. Adapters are attached to the module names
# returned by get_all_linear_layers(model); biases are not trained.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=get_all_linear_layers(model),
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
# Wrap the prepared model so only the adapter weights are trainable.
model = get_peft_model(model, config)

## Prepare dataset

In [None]:
# Fixed seed for the train/eval split: without it, every kernel restart draws
# a different 10% hold-out, so eval losses and resumed checkpoints would not
# refer to the same rows.
SPLIT_SEED = 42

# Load the glaive function-calling corpus and hold out 10% for evaluation.
dataset = load_dataset("glaiveai/glaive-function-calling-v2", split='train')
dataset = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
dataset

In [None]:
# Show the first training row so the raw 'system' and 'chat' fields (the two
# fields format_and_tokenize_prompt reads) can be inspected.
dataset['train'][0]

In [None]:
# Prompt templates. The training template wraps the system prompt in
# <|im_start|>/<|im_end|> markers and then appends the already-rendered turns.
# NOTE(review): the role label here is "System", while the turn labels built in
# postprocess_conversation_msgs are lowercase. Also, eval_prompt_format uses
# Mistral's [INST] syntax instead of these markers. Confirm the intended
# inference format before relying on eval_prompt_format.
train_prompt_format = (
    "<|im_start|>System\n"
    "{sys_msgs}<|im_end|>\n"
    "{conversation_msgs}"
)
eval_prompt_format = "[INST] {prompt} [/INST]"

# Transcript prefixes in the glaive "chat" field, mapped to the role label written
# after <|im_start|>. Order matters: the function-call prefix must be tested
# before the plain "ASSISTANT:" prefix it starts with.
_ROLE_PREFIXES = (
    ('USER:', 'user'),
    ('ASSISTANT: <functioncall>', 'function'),
    ('FUNCTION RESPONSE:', 'function_response'),
    ('ASSISTANT:', 'assistant'),
)
# Roles whose text may carry a trailing "<|endoftext|>" marker that is removed.
_STRIP_EOT_ROLES = frozenset({'function', 'assistant'})


def postprocess_conversation_msgs(msgs):
    """Convert a ``USER:/ASSISTANT:`` transcript into ChatML-style turns.

    A line that starts with one of the known prefixes opens a new turn. Any
    following line without a prefix is a continuation of the current turn,
    so multi-line answers (lists, code, paragraphs) are kept intact. Text
    before the first recognised prefix is dropped.

    Args:
        msgs: the raw transcript string.

    Returns:
        The turns rendered as ``<|im_start|>{role}\\n{text}<|im_end|>`` and
        joined by newlines, or an empty string if no turn was found.
    """
    turns = []  # (role, [text lines]) in transcript order
    for line in msgs.strip().split('\n'):
        for prefix, role in _ROLE_PREFIXES:
            if line.startswith(prefix):
                # Slice off only the leading prefix. str.replace would also
                # delete the marker wherever it appears later in the message.
                turns.append((role, [line[len(prefix):]]))
                break
        else:
            if turns:
                turns[-1][1].append(line)

    rendered = []
    for role, parts in turns:
        # strip() also drops the blank separator lines between turns.
        text = '\n'.join(parts).strip()
        if role in _STRIP_EOT_ROLES:
            text = text.replace('<|endoftext|>', '').strip()
        rendered.append(f"<|im_start|>{role}\n{text}<|im_end|>")
    return '\n'.join(rendered)
    
def format_and_tokenize_prompt(tokenizer, data, max_length=512):
    """Render one dataset row into the training prompt and tokenize it.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the prompt.
        data: a dataset row with a ``'system'`` string (normally containing a
            ``SYSTEM:`` marker) and a ``'chat'`` transcript string.
        max_length: maximum number of tokens kept. Longer prompts are
            truncated; presumably from the right (the tokenizer default),
            so the final turns are the ones cut — confirm.

    Returns:
        The tokenizer's encoding of the rendered prompt.
    """
    raw_system = data['system']
    # Keep only the text after the first "SYSTEM:" marker. If the marker is
    # missing, fall back to the whole field instead of failing on unpacking.
    _, marker, after_marker = raw_system.partition('SYSTEM:')
    system_msgs = (after_marker if marker else raw_system).strip()
    conversation_msgs = postprocess_conversation_msgs(data['chat'])
    train_prompt = train_prompt_format.format(sys_msgs=system_msgs, conversation_msgs=conversation_msgs)
    return tokenizer(train_prompt, max_length=max_length, truncation=True)
    
def _tokenize_split(split_name):
    """Render and tokenize every row of ``dataset[split_name]``.

    No ``remove_columns`` is passed, so the original columns are kept next to
    the tokenizer outputs.
    """
    return dataset[split_name].map(lambda sample: format_and_tokenize_prompt(tokenizer, sample))


train_dataset = _tokenize_split('train')
test_dataset = _tokenize_split('test')

## Create Trainer

In [None]:
# Gate bf16 mixed precision on the CPU's instruction set. bf16_flag is the
# result of BigDL's AVX-512 check (presumably True only on AVX-512 capable
# CPUs — confirm against ISAChecker). Without AVX-512, training runs without
# bf16 instead of forcing it on.
isa_checker = ISAChecker()
bf16_flag = isa_checker.check_avx512()
trainer_args = transformers.TrainingArguments(
    output_dir="./data/checkpoints",
    # Evaluate and checkpoint once per epoch. load_best_model_at_end requires
    # the evaluation and save strategies to match.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    weight_decay=0,
    num_train_epochs=3,
    lr_scheduler_type='cosine',
    warmup_steps=0,
    logging_strategy="steps",
    logging_steps=1,
    bf16=bf16_flag,
    optim="adamw_hf",
    # Kept off to match use_gradient_checkpointing=False in the model preparation.
    gradient_checkpointing=False,
)

In [None]:
# Causal-LM collator (mlm=False): pads each batch and builds `labels` from
# `input_ids`, with padded positions ignored in the loss.
lm_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
trainer = transformers.Trainer(
    model=model,
    args=trainer_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=lm_collator,
)
# KV caching is switched off for training; set it back to True before generation.
model.config.use_cache = False

In [None]:
# Run fine-tuning, keep the summary returned by Trainer.train() in `result`,
# and print it.
print(result := trainer.train())