### Improve LLM for custom tool calling with a SQL Database

In [2]:
!pip install -qqq jmespath
!pip install -qqq git+https://github.com/huggingface/transformers.git
!pip install -qqq git+https://github.com/huggingface/trl.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.0/521.0 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 5.1.2 requires transformers<5.0.0,>=4.41.0, but you have transformers 5.0.0.dev0 which is incompatible.[0m[31m
[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for trl (pyproject.toml) ... [?25l[?25hdone


Tools: Custom Knowledge Base
Hotel booking agent
In this tutorial we are going to build and train a custom agent for managing hotel bookings. We'll set up a postgres database and expose tools to the agent, and then use the GRPO algorithm from TRL to improve the agents performance.

Start a postgres server which we'll use as a demo knowledge base.

In [3]:
# Install postgresql to run a DB server on colab
%%shell

sudo apt-get -y -qq update > /dev/null 2>&1
sudo apt-get -y -qq install postgresql > /dev/null 2>&1



In [4]:
# Start the postgresql server.
!sudo service postgresql start

# Check that postgres is running
!sudo lsof -i :5432

 * Starting PostgreSQL 14 database server
   ...done.
COMMAND   PID     USER   FD   TYPE DEVICE SIZE/OFF NODE NAME
postgres 3473 postgres    5u  IPv6  51111      0t0  TCP localhost:postgresql (LISTEN)
postgres 3473 postgres    6u  IPv4  51112      0t0  TCP localhost:postgresql (LISTEN)


We have a postgres server, so let's setup a database.

In [5]:
# Create a dedicated database and a user to access our DB securely
%%shell

sudo -u postgres psql << EOF
CREATE USER hotel_user WITH PASSWORD 'my-password';
CREATE DATABASE hotel_db;
GRANT ALL PRIVILEGES ON DATABASE hotel_db TO hotel_user;
ALTER DATABASE hotel_db OWNER TO hotel_user;
EOF

CREATE ROLE
CREATE DATABASE
GRANT
ALTER DATABASE




Now we're going to populate out database with dummy hotels.

In [6]:
# Connect to the database with the new user and create a hotels table.
%%shell

export PGPASSWORD=my-password
psql -h 127.0.0.1 -U hotel_user -d hotel_db --no-password << EOF
CREATE TABLE hotels(
   id            INTEGER NOT NULL PRIMARY KEY,
   name          VARCHAR NOT NULL,
   location      VARCHAR NOT NULL,
   price_tier    VARCHAR NOT NULL,
   checkin_date  DATE    NOT NULL,
   checkout_date DATE    NOT NULL,
   booked        BIT     NOT NULL
);
INSERT INTO hotels(id, name, location, price_tier, checkin_date, checkout_date, booked)
VALUES
  (1, 'Hilton Basel', 'Basel', 'Luxury', '2024-04-22', '2024-04-20', B'0'),
  (2, 'Marriott Zurich', 'Zurich', 'Upscale', '2024-04-14', '2024-04-21', B'0'),
  (3, 'Hyatt Regency Basel', 'Basel', 'Upper Upscale', '2024-04-02', '2024-04-20', B'0'),
  (4, 'Radisson Blu Lucerne', 'Lucerne', 'Midscale', '2024-04-24', '2024-04-05', B'0'),
  (5, 'Best Western Bern', 'Bern', 'Upper Midscale', '2024-04-23', '2024-04-01', B'0'),
  (6, 'InterContinental Geneva', 'Geneva', 'Luxury', '2024-04-23', '2024-04-28', B'0'),
  (7, 'Sheraton Zurich', 'Zurich', 'Upper Upscale', '2024-04-27', '2024-04-02', B'0'),
  (8, 'Holiday Inn Basel', 'Basel', 'Upper Midscale', '2024-04-24', '2024-04-09', B'0'),
  (9, 'Courtyard Zurich', 'Zurich', 'Upscale', '2024-04-03', '2024-04-13', B'0'),
  (10, 'Comfort Inn Bern', 'Bern', 'Midscale', '2024-04-04', '2024-04-16', B'0');
SELECT * from hotels;
EOF

CREATE TABLE
INSERT 0 10
 id |          name           | location |   price_tier   | checkin_date | checkout_date | booked 
----+-------------------------+----------+----------------+--------------+---------------+--------
  1 | Hilton Basel            | Basel    | Luxury         | 2024-04-22   | 2024-04-20    | 0
  2 | Marriott Zurich         | Zurich   | Upscale        | 2024-04-14   | 2024-04-21    | 0
  3 | Hyatt Regency Basel     | Basel    | Upper Upscale  | 2024-04-02   | 2024-04-20    | 0
  4 | Radisson Blu Lucerne    | Lucerne  | Midscale       | 2024-04-24   | 2024-04-05    | 0
  5 | Best Western Bern       | Bern     | Upper Midscale | 2024-04-23   | 2024-04-01    | 0
  6 | InterContinental Geneva | Geneva   | Luxury         | 2024-04-23   | 2024-04-28    | 0
  7 | Sheraton Zurich         | Zurich   | Upper Upscale  | 2024-04-27   | 2024-04-02    | 0
  8 | Holiday Inn Basel       | Basel    | Upper Midscale | 2024-04-24   | 2024-04-09    | 0
  9 | Courtyard Zurich        | Z



In [7]:
!pip install psycopg2-binary

import psycopg2
from typing import List, Optional

# Database connection parameters
DB_PARAMS = {
    "host": "127.0.0.1",
    "port": "5432",
    "database": "hotel_db",
    "user": "hotel_user",
    "password": "my-password"
}

def get_connection():
    return psycopg2.connect(**DB_PARAMS)

def search_hotels_by_name(name: str) -> List[tuple]:
    """
    Search for hotels based on name.

    Args:
        name: The name of the hotel.

    Returns:
        A list of hotels matching the name.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM hotels WHERE name ILIKE %s", ('%' + name + '%',))
            return cur.fetchall()
    finally:
        conn.close()

def search_hotels_by_location(location: str) -> List[tuple]:
    """
    Search for hotels based on location.

    Args:
        location: The location of the hotel.

    Returns:
        A list of hotels matching the location.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM hotels WHERE location ILIKE %s", ('%' + location + '%',))
            return cur.fetchall()
    finally:
        conn.close()

def book_hotel(hotel_id: str):
    """
    Book a hotel by its ID. If the hotel is successfully booked, returns None.

    Args:
        hotel_id: The ID of the hotel to book.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("UPDATE hotels SET booked = B'1' WHERE id = %s", (hotel_id,))
        conn.commit()
    finally:
        conn.close()

def update_hotel(hotel_id: str, checkin_date: str, checkout_date: str) -> str:
    """
    Update a hotel's check-in and check-out dates by its ID.
    Returns a message indicating whether the hotel was successfully updated or not.

    Args:
        hotel_id: The ID of the hotel to update.
        checkin_date: The new check-in date of the hotel.
        checkout_date: The new check-out date of the hotel.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE hotels SET checkin_date = %s, checkout_date = %s WHERE id = %s",
                (checkin_date, checkout_date, hotel_id)
            )
        conn.commit()
        return "Hotel updated successfully"
    except Exception as e:
        return f"Failed to update hotel: {e}"
    finally:
        conn.close()

def cancel_hotel(hotel_id: str):
    """
    Cancel a hotel by its ID.

    Args:
        hotel_id: The ID of the hotel to cancel.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("UPDATE hotels SET booked = B'0' WHERE id = %s", (hotel_id,))
        conn.commit()
    finally:
        conn.close()

Collecting psycopg2-binary
  Downloading psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (4.9 kB)
Downloading psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (4.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: psycopg2-binary
Successfully installed psycopg2-binary-2.9.11


# Dataset

In [8]:
from datasets import Dataset

prompt = """
  You're a helpful hotel assistant. You handle hotel searching, booking and
  cancellations. When the user searches for a hotel, mention it's name, id,
  location and price tier. Always mention hotel ids while performing any
  searches. This is very important for any operations. For any bookings or
  cancellations, please provide the appropriate confirmation. Be sure to
  update checkin or checkout dates if mentioned by the user.
  Don't ask for confirmations from the user.
"""

dataset = Dataset.from_list(
    [
        {
            "prompt": [
                {"role": "system", "content": prompt},
                {"role": "user", "content": "Find hotels in Basel with Basel in it's name."},
            ],
            "answer": "Hilton Basel"
        },
        {
            "prompt": [
                {"role": "system", "content": prompt},
                {"role": "user", "content": "Can you book the Hilton Basel for me?"},
            ],
            "answer": "booked"
        },
        {
            "prompt": [
                {"role": "system", "content": prompt},
                {
                    "role": "user",
                    "content": "Oh wait, this is too expensive. Please cancel it and book the Hyatt Regency instead.",
                },
            ],
            "answer": "Hyatt Regency Basel"
        },
        {
            "prompt": [
                {"role": "system", "content": prompt},
                {
                    "role": "user",
                    "content": "My check in dates would be from April 10, 2024 to April 19, 2024.",
                },
            ],
            "answer": "updated"
        },
    ]
)

In [9]:
import wandb

# Initialize Weights & Biases for experiment tracking
wandb.init(project="transformer-fine-tuning", name="grpo-analysis")


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msulbha-jindal[0m ([33msulbha-jindal-amazon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# !wandb login --relogin

[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


# Train

In [11]:
def accuracy(completions, answer, **kwargs):
    predictions = [completion[-1]["content"] for completion in completions]
    rewards = [float(str(ans).lower() in pred.lower()) for pred, ans in zip(predictions, answer)]
    return rewards

In [12]:
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", device_map="auto")

# Load LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)

# Training arguments
training_args = GRPOConfig(
    output_dir="hotel-agent",
    learning_rate=1e-4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_generations=4,
    num_train_epochs=16,
    report_to="wandb",
    beta = 0.05
    # warmup_ratio=0.03,
    # lr_scheduler_type="constant",
)

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    tools=[
        search_hotels_by_location,
        search_hotels_by_name,
        book_hotel,
        update_hotel,
        cancel_hotel,
    ],
    reward_funcs=accuracy,
)
trainer.train()

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/311 [00:00<?, ?it/s]



generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
`generation_config` default values have been modified to match model-specific defaults: {'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.


Step,Training Loss
10,0.0
20,0.031539
30,0.00038
40,0.002807
50,0.000576
60,0.00044


TrainOutput(global_step=64, training_loss=0.005605631255093613, metrics={'train_runtime': 2840.6841, 'train_samples_per_second': 0.023, 'train_steps_per_second': 0.023, 'total_flos': 0.0, 'train_loss': 0.005605631255093613})

# Generate

In [17]:
trainer.model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 1024)
        (layers): ModuleList(
          (0-27): 28 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (bas

In [25]:
import torch

messages = [
  {"role": "system", "content": prompt},
  {"role": "user", "content": "Find a hotel in Bern"}
]

tools = tools=[
        search_hotels_by_location,
        search_hotels_by_name,
        book_hotel,
        update_hotel,
        cancel_hotel,
]

inputs = tokenizer.apply_chat_template(messages
                                       , tools=tools
                                       , add_generation_prompt=True
                                       , return_dict=True
                                       , return_tensors="pt"
                                       , enable_thinking=False)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=1028)
print(tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):]))

<tool_call>
{"name": "search_hotels_by_location", "arguments": {"location": "Bern"}}
</tool_call><|im_end|>


In [26]:
inputs = tokenizer.apply_chat_template(messages
                                       , tools=tools
                                       , add_generation_prompt=True
                                       , return_dict=True
                                       , return_tensors="pt"
                                       , enable_thinking=False)
outputs = trainer.model.generate(**inputs.to(model.device), max_new_tokens=1028)
print(tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):]))

<tool_call>
{"name": "search_hotels_by_location", "arguments": {"location": "Bern"}}
</tool_call><|im_end|>


In [27]:
import torch

messages = [
  {"role": "system", "content": prompt},
  {"role": "user", "content": "Book a hotel in Bern for April 4, 2024 thru April 16"}
]

tools = tools=[
        search_hotels_by_location,
        search_hotels_by_name,
        book_hotel,
        update_hotel,
        cancel_hotel,
]

inputs = tokenizer.apply_chat_template(messages
                                       , tools=tools
                                       , add_generation_prompt=True
                                       , return_dict=True
                                       , return_tensors="pt"
                                       , enable_thinking=False)
outputs = trainer.model.generate(**inputs.to(model.device), max_new_tokens=1028)
print(tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):]))

<tool_call>
{"name": "book_hotel", "arguments": {"hotel_id": "HOTEL_123", "checkin_date": "2024-04-04", "checkout_date": "2024-04-16"}}
</tool_call><|im_end|>
