In [1]:
%%capture
!pip install --upgrade pip
!pip install transformers torch trl accelerate peft datasets bitsandbytes huggingface_hub

In [None]:
import huggingface_hub

# Interactive Hugging Face Hub login. Later cells download
# google/gemma-2-2b-it and push the merged model to the Hub.
huggingface_hub.notebook_login()

In [1]:
import torch
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig

In [None]:
# Instruction-tuned Gemma 2 (2B) checkpoint that the LoRA adapter is trained on.
MODEL_ID = "google/gemma-2-2b-it"

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    torch_dtype=torch.bfloat16,
    # NOTE(review): eager attention kept as chosen originally; transformers
    # recommends it for Gemma 2 training — confirm for the installed version.
    attn_implementation="eager",
    device_map="cuda:0",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Keep the tokenizer's own pad token when it has one (Gemma's tokenizer defines
# a dedicated <pad>). Overwriting it with EOS makes every EOS look like padding.
# Collators that mask pad positions out of the loss (e.g. transformers'
# DataCollatorForLanguageModeling) would then hide the EOS appended to each
# training example, and the model would never learn to stop. Fall back to EOS
# only for tokenizers without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [3]:
# LoRA adapter hyper-parameters: rank 64 with lora_alpha 256 (alpha/r = 4),
# no adapter dropout and no bias training. Adapters are attached to the
# q/k/v/o projections and the gate/up/down projections of every layer.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=256,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

In [4]:
# Wrap the base model with trainable LoRA adapters described by lora_config.
peft_model = get_peft_model(model=model, peft_config=lora_config)

In [28]:
def _gemma_exchange(user_parts, response):
    """Render one user -> model exchange in Gemma's turn markup.

    Each element of ``user_parts`` is placed on its own line inside the user
    turn, and ``response`` fills the model turn. No indentation or trailing
    whitespace is added, so the layout matches the unindented inference prompt.

    Args:
        user_parts: iterable of values to place, newline-separated, in the
            user turn.
        response: content of the model turn.

    Returns:
        The formatted exchange, ending with ``"<end_of_turn>\\n"``.
    """
    user_block = "\n".join(f"{part}" for part in user_parts)
    return (
        f"<start_of_turn>user\n{user_block}\n<end_of_turn>\n"
        f"<start_of_turn>model\n{response}\n<end_of_turn>\n"
    )


def tool_call_format(system, tools, user, response):
    """Format a training example whose prompt includes tool definitions.

    Args:
        system: instructions placed at the start of the user turn.
        tools: the tool block (already wrapped by the caller, e.g. in
            ``<tools>`` tags), placed after the instructions.
        user: the user's message.
        response: the target model reply.

    Returns:
        The formatted exchange string.
    """
    return _gemma_exchange((system, tools, user), response)


def general_format(system, user, response):
    """Format a training example without tool definitions.

    Args:
        system: instructions placed at the start of the user turn.
        user: the user's message.
        response: the target model reply.

    Returns:
        The formatted exchange string.
    """
    return _gemma_exchange((system, user), response)

In [29]:
# End-of-sequence marker appended to every training string by
# formatting_prompts_func.
EOS_TOKEN = tokenizer.eos_token

# Instructions placed at the top of the user turn for examples that come with
# tool definitions; the inference cell reuses this same text.
tool_call_system = (
    "You are a helpful AI assistant that has access to a set of tools listed "
    "between the <tools> xml tags that you may call to help the user. Only use "
    "them if the user query requires them. For each tool call return a json "
    "object with the name of the tool and its arguments surrounded by "
    "<tool_call> xml tags. If you decide not to use a tool respond normally "
    "and answer the user's query."
)
# Instructions for examples that have no tool definitions.
general_system = "You are a helpful AI assistant and you answer the user's questions clearly and correctly."

In [30]:
def formatting_prompts_func(examples):
    """Turn one dataset row into a ``{"text": ...}`` training example.

    Only the first two messages of ``examples["conversations"]`` are used: the
    opening user message and the reply to it. A reply whose ``"from"`` is not
    ``"gpt"`` is wrapped in ``<tool_call>`` tags. Rows whose ``"tools"`` value
    is missing (``None``), empty or ``"[]"`` use the general prompt. All other
    rows use the tool prompt, with the tools text wrapped in ``<tools>`` tags.
    ``EOS_TOKEN`` is appended in both cases.

    NOTE(review): any messages after the first two are ignored, so later turns
    of multi-turn conversations are not trained on.

    Args:
        examples: a single dataset row (``dataset.map`` is called without
            ``batched=True``), despite the plural name.

    Returns:
        dict with a single ``"text"`` key holding the formatted example.
    """
    conversation = examples["conversations"]
    tools = examples["tools"]

    user = conversation[0]["value"]
    reply = conversation[1]
    response = reply["value"]
    if reply["from"] != "gpt":
        # Non-"gpt" replies are trained as tool invocations.
        response = f"<tool_call>\n{response}\n</tool_call>"

    # Guard against a missing tools value instead of crashing on .strip().
    if tools is None or tools.strip() in ("[]", ""):
        text = general_format(general_system, user, response)
    else:
        tools_block = f"<tools>\n{tools}\n</tools>"
        text = tool_call_format(tool_call_system, tools_block, user, response)

    return {"text": text + EOS_TOKEN}

In [None]:
# Load the raw tool-calling conversations, render each row into a "text"
# field, then hold out 10% for evaluation (fixed seed for a reproducible split).
raw_dataset = load_dataset("llamafactory/glaive_toolcall_en", split="train")
formatted_dataset = raw_dataset.map(formatting_prompts_func)
dataset = formatted_dataset.train_test_split(test_size=0.1, seed=20)

In [None]:
# 500 optimizer steps at batch size 1; evaluate and checkpoint every 100 steps.
args = TrainingArguments(
    output_dir="main",
    seed=3407,
    # optimisation
    max_steps=500,
    per_device_train_batch_size=1,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=50,
    weight_decay=1e-3,
    bf16=True,
    # logging / evaluation / checkpoint cadence
    logging_dir="logs",
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
)

# NOTE(review): passing tokenizer / dataset_text_field / max_seq_length /
# packing directly to SFTTrainer only works on older trl releases. Newer ones
# move these into SFTConfig (and rename tokenizer -> processing_class).
# Confirm the installed trl version.
trainer = SFTTrainer(
    model=peft_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="text",
    max_seq_length=4096,
    packing=False,
)

In [None]:
# Run the fine-tuning loop configured in the previous cell; keep whatever
# Trainer.train() returns for the run in `trainer_stats`.
trainer_stats = trainer.train()

In [None]:
import json

from transformers import TextStreamer

# Gemma turn markup with the model turn left open, so generation continues it.
gemma_prompt = """<start_of_turn>user
{}
<end_of_turn>
<start_of_turn>model
"""

# Tool schemas shown to the model. They are built as Python data and
# serialised with json.dumps, so the <tools> payload is always valid JSON.
_NOT_A_PATH = "It doesn't have to be a path, just a worded explanation."


def _string_param(description):
    """Schema for a string-typed tool argument."""
    return {"type": "string", "description": description}


def _tool(name, description, properties):
    """Function-calling schema for one tool; every listed property is required."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": list(properties),
        },
    }


tools_spec = [
    _tool("move_file", "Move a file from one location to another", {
        "source": _string_param(f"The file that should be moved. {_NOT_A_PATH}"),
        "target": _string_param(
            f"The destination of where the file should be moved to. {_NOT_A_PATH}"
        ),
    }),
    _tool("copy_file", "Copy a file from one location to another", {
        "source": _string_param(f"The file that should be copied. {_NOT_A_PATH}"),
        "target": _string_param(
            f"The destination of where the file should be copied to. {_NOT_A_PATH}"
        ),
    }),
    _tool("rename_file", "Rename a file.", {
        "source": _string_param(f"The file that should be renamed. {_NOT_A_PATH}"),
        "new_name": _string_param("The exact name of the new file."),
    }),
    _tool("goto_file", "Go to a file.", {
        "target": _string_param(f"The target file. {_NOT_A_PATH}"),
    }),
    _tool("open_file", "Open any type of file on the computer.", {
        "target": _string_param(f"The target file. {_NOT_A_PATH}"),
    }),
    _tool("delete_file", "Delete a file.", {
        "target": _string_param(f"The target file. {_NOT_A_PATH}"),
    }),
    _tool(
        "local_search",
        "Retrieve relevant information from local knowledge base to enhance "
        "response accuracy for the user's specific query.",
        {"query": _string_param("The user's query. The thing they are looking for.")},
    ),
]
# The leading newline separates the tool block from the system text before it.
tool_list = "\n<tools>\n" + json.dumps(tools_spec, indent=2) + "\n</tools>"

prompt = "\nOpen the apple image."

# One placeholder: system instructions + tool block + user query.
inputs = tokenizer(
    [gemma_prompt.format(tool_call_system + tool_list + prompt)],
    return_tensors="pt",
).to(peft_model.device)

# Switch the adapter model to inference mode before generating.
peft_model.eval()
text_streamer = TextStreamer(tokenizer)
_ = peft_model.generate(**inputs, streamer=text_streamer, max_new_tokens=500)

In [None]:
# Persist the LoRA adapter (peft_model) and the tokenizer side by side.
ADAPTER_DIR = "lora_model"
for artifact in (peft_model, tokenizer):
    artifact.save_pretrained(ADAPTER_DIR)

In [None]:
# Fold the trained LoRA weights into the base model and drop the adapter
# wrappers (peft `merge_and_unload`); the next cell saves the result.
# NOTE(review): per the peft docs the merge modifies the underlying base model
# in place, so avoid reusing `peft_model` for training/inference after this
# cell — confirm for the installed peft version.
merged_model = peft_model.merge_and_unload()

In [None]:
# Write the merged (adapter-free) model and the tokenizer to one folder,
# which the upload cell pushes to the Hub.
MERGED_DIR = "./merged_model"
for artifact in (merged_model, tokenizer):
    artifact.save_pretrained(MERGED_DIR)

In [None]:
from huggingface_hub import HfApi

api = HfApi()
repo_name = "trishonc/gemma-2-2b-it-tool-use"

# upload_folder needs the target repo to exist; create it on a fresh run
# (no-op when it already exists). NOTE(review): a newly created repo gets the
# account's default visibility — pass private=True if it should be private.
api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)

# Push the merged model + tokenizer saved in the previous cell.
api.upload_folder(
    folder_path="./merged_model",
    repo_id=repo_name,
    repo_type="model",
)