In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def create_model(
    lm_model_name = "gemma-3-4b-it", 
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
):
    """Load a causal language model and its tokenizer for inference.

    Args:
        lm_model_name: Model identifier or local path passed to
            ``from_pretrained`` for both the tokenizer and the model.
        device: Where to place the weights; forwarded as ``device_map``.
            The default is ``'cuda'`` if CUDA was available when this
            function was defined, otherwise ``'cpu'``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
    # Place the weights on the requested device at load time instead of
    # loading with device_map="auto" and calling .to(device) afterwards:
    # moving a model that accelerate has already dispatched is not supported
    # once it spans more than one device, and "auto" ignored `device`.
    model = AutoModelForCausalLM.from_pretrained(
        lm_model_name,
        device_map=device,
    ).eval()

    return model, tokenizer

In [3]:
def lm_template(
    text: str, 
    system_prompt: str = "You are a helpful assistant.", 
):
    """Build a two-message chat (system, then user) in content-parts form.

    Args:
        text: The user's message.
        system_prompt: Instruction text placed in the system message.

    Returns:
        A list of two message dicts, each with a ``role`` and a ``content``
        list holding a single text part.
    """
    def _text_message(role, body):
        # Every message carries exactly one part of type "text".
        return {"role": role, "content": [{"type": "text", "text": body}]}

    return [
        _text_message("system", system_prompt),
        _text_message("user", text),
    ]

In [4]:
@torch.inference_mode()
def generate(
    prompt, 
    tokenizer, 
    model, 
    max_new_tokens: int = 256, 
    temperature: float = 1,
):
    """Run one chat completion and return only the newly generated text.

    Args:
        prompt: Chat messages passed to ``tokenizer.apply_chat_template``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        model: Model providing ``device``, ``dtype``, ``config`` and
            ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding instead of sampling.

    Returns:
        The decoded continuation with special tokens removed.

    Raises:
        ValueError: If the tokenized prompt is longer than the configured
            ``max_position_embeddings``.
    """
    encoded = tokenizer.apply_chat_template(
        prompt, 
        add_generation_prompt=True, 
        tokenize=True,
        return_dict=True, 
        return_tensors="pt",
    )
    # Move everything to the model's device; only floating-point tensors are
    # cast to the model dtype so integer ids and masks keep their type.
    inputs = {
        key: (
            tensor.to(model.device, dtype=model.dtype)
            if tensor.dtype.is_floating_point else tensor.to(model.device)
        )
        for key, tensor in encoded.items()
    }

    input_len = inputs["input_ids"].shape[-1]

    # Some configs nest the language-model settings under `text_config`;
    # others expose them at the top level. Accept both layouts and skip the
    # check when no limit is published.
    text_config = getattr(model.config, "text_config", model.config)
    max_len = getattr(text_config, "max_position_embeddings", None)
    if max_len is not None:
        max_len = int(max_len)
        if input_len > max_len:
            raise ValueError(
                f"Input length {input_len} exceeds maximum allowed length of {max_len} tokens."
            )

    # Sampling needs a strictly positive temperature, so a non-positive value
    # is treated as a request for deterministic (greedy) decoding.
    if temperature is not None and temperature <= 0:
        decoding_kwargs = {"do_sample": False}
    else:
        decoding_kwargs = {"do_sample": True, "temperature": temperature}

    generation = model.generate(
        **inputs, 
        max_new_tokens=max_new_tokens, 
        **decoding_kwargs,
    )
    # The returned sequence starts with the prompt tokens; drop them.
    generation = generation[0][input_len:]

    response = tokenizer.decode(
        generation, 
        skip_special_tokens=True
    )

    return response

In [5]:
import ast
import json
import operator
import re
from typing import Dict, List, Optional, Union

In [6]:
# Largest integer power result (estimated in bits) that the calculator will
# compute; keeps inputs such as "9**9**9**9" from hanging the kernel.
_MAX_POW_RESULT_BITS = 10_000


def _checked_pow(base, exponent):
    """Raise ``base`` to ``exponent``, refusing huge integer results."""
    if isinstance(base, int) and isinstance(exponent, int) and exponent > 0:
        if base.bit_length() * exponent > _MAX_POW_RESULT_BITS:
            raise ValueError("exponent too large")
    return operator.pow(base, exponent)


# Operators reachable with the characters the calculator accepts.
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Pow: _checked_pow,
}
_UNARY_OPERATORS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}


def _evaluate_arithmetic(node):
    """Evaluate a parsed arithmetic expression made of numbers and operators."""
    if isinstance(node, ast.Expression):
        return _evaluate_arithmetic(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPERATORS:
        return _UNARY_OPERATORS[type(node.op)](_evaluate_arithmetic(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPERATORS:
        left = _evaluate_arithmetic(node.left)
        right = _evaluate_arithmetic(node.right)
        return _BINARY_OPERATORS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


def calculator(
    expression: str
) -> Union[int, float, str]:
    """Evaluate an arithmetic expression without executing arbitrary code.

    The expression may contain digits, whitespace, ``+ - * / ( )`` and ``.``
    (so ``**`` and ``//`` are accepted too). It is parsed with ``ast`` and
    only numeric literals and arithmetic operators are evaluated.

    Args:
        expression: The expression text, e.g. ``"(10 + 20) / 2"``.

    Returns:
        The numeric result on success. On failure a message string is
        returned instead of raising: ``"Error: ..."`` for disallowed
        characters, ``"Calculation error: ..."`` for anything else.
    """
    try:
        if not re.match(r'^[\d\s\+\-\*/\(\)\.]+$', expression):
            return "Error: Expression contains invalid characters"
        # Surrounding whitespace is stripped because ast.parse, unlike eval,
        # rejects leading indentation.
        tree = ast.parse(expression.strip(), mode="eval")
        return _evaluate_arithmetic(tree)
    except Exception as e:
        return f"Calculation error: {str(e)}"

# Tool specification for the calculator: the name, description and JSON-schema
# style "parameters" are rendered into the system prompt, while "function" is
# the Python callable that actually runs the tool.
calculator_info = {
    "name": "calculator",
    "description": "Execute mathematical calculations. Use this tool when you need to perform precise numerical computations.",
    "parameters": {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to calculate, e.g., '123 * 456' or '(10 + 20) / 2'",
            },
        },
        "required": ["expression"],
    },
    "function": calculator,
}

# Registry of available tools, keyed by the tool's own name.
TOOL_DICT = {calculator_info["name"]: calculator_info}

In [7]:
# Instructions appended after the tool list: they define the
# <function_call> JSON envelope the model must emit to request a tool.
SYSTEM_PROMPT = f"""
When you need to use a tool, respond in the following format:
<function_call>
{{"name": "tool_name", "arguments": {{"param_name": "param_value"}}}}
</function_call>

If you don't need to use a tool, respond directly to the question.

Important:
- Only call a function when it's necessary to answer the question
- Make sure the function call JSON is properly formatted
- After receiving the function result, provide a natural language response to the user
"""

def build_system_prompt_with_tools(
    tools: Dict
) -> str:
    """Compose the system prompt: a catalogue of tools, then SYSTEM_PROMPT.

    Args:
        tools: Mapping whose values are tool specs with ``name``,
            ``description`` and ``parameters`` entries.

    Returns:
        The tool catalogue (one section per tool, parameters rendered as
        indented JSON) followed by the calling-format instructions.
    """
    tool_sections = [
        f"Tool: {spec['name']}\n"
        f"Description: {spec['description']}\n"
        f"Parameters: {json.dumps(spec['parameters'], indent=2)}\n\n"
        for spec in tools.values()
    ]
    catalogue = "You have access to the following tools:\n\n" + "".join(tool_sections)

    return catalogue + SYSTEM_PROMPT

In [8]:
def execute_tool(
    tool_name: str, 
    arguments: Dict, 
    tools: Dict, 
) -> str:
    """Look up a tool by name, call it with the given arguments, return text.

    Arguments are passed to the tool's ``function`` as keyword arguments.
    When the tool spec declares ``parameters.properties``, keys that are not
    declared there are dropped, and keys listed in ``parameters.required``
    must be present.

    Args:
        tool_name: Key into ``tools``.
        arguments: Mapping of argument names to values for the tool.
        tools: Registry of tool specs; each spec holds the callable under
            ``"function"``.

    Returns:
        ``str(result)`` on success, otherwise a human-readable error message.
        This function does not raise for unknown tools, malformed arguments
        or failures inside the tool.
    """
    # Resolve the tool first so a KeyError raised inside the tool itself is
    # not mistaken for a missing tool. TypeError covers unhashable names.
    try:
        tool = tools[tool_name]
    except (KeyError, TypeError):
        return f"Unknown tool: {tool_name}"

    if not isinstance(arguments, dict):
        return (
            f"Invalid arguments for tool {tool_name}: "
            f"expected an object, got {type(arguments).__name__}"
        )

    schema = tool.get("parameters") or {}
    declared = schema.get("properties") or {}
    missing = [name for name in schema.get("required", []) if name not in arguments]
    if missing:
        return f"Missing required argument(s) for tool {tool_name}: {', '.join(missing)}"

    if declared:
        # Ignore extra keys the model invented rather than failing the call.
        call_kwargs = {name: value for name, value in arguments.items() if name in declared}
    else:
        call_kwargs = dict(arguments)

    try:
        return str(tool['function'](**call_kwargs))
    except Exception as exc:
        # Report the failure as text so it can be fed back to the model.
        return f"Tool execution error: {exc}"

def result2prompt(
    user_query: str, 
    tool_name: str, 
    arguments: str, 
    tool_result: str, 
) -> str:
    """Format a tool result as the next user message for the model.

    The message states what ``tool_name(arguments)`` produced and asks for a
    natural-language answer to the original ``user_query``.
    """
    follow_up = f"The result of {tool_name}({arguments}) is: {tool_result}\n\nPlease provide a natural language response to the original question: {user_query}"
    return follow_up

def parse_function_call(
    response: str
) -> Optional[Dict]:
    """Extract the first ``<function_call>`` JSON object from a response.

    Returns:
        The decoded JSON object, or ``None`` when the response contains no
        ``<function_call>...</function_call>`` block or its payload is not
        valid JSON.
    """
    found = re.search(
        r'<function_call>\s*(\{.*?\})\s*</function_call>',
        response,
        flags=re.DOTALL,
    )
    if found is None:
        return None

    payload = found.group(1)
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        # Malformed JSON is treated the same as "no call requested".
        return None

def filter_function_call(
    response: str
):
    """Return ``response`` with every ``<function_call>`` block removed."""
    call_block = re.compile(r'<function_call>[\s\S]*?</function_call>')
    return call_block.sub('', response)

In [9]:
def run_agent(
    user_query: str,
    model,
    tokenizer,
    tools,
    max_iterations: int = 3,
    max_new_tokens: int = 256,
    temperature: float = 1,
) -> str:
    """Answer a query, letting the model call tools for a bounded number of turns.

    Each turn generates a response. If it contains a function call, the tool
    is executed and its result is sent back as the next prompt; otherwise the
    response is returned as the final answer.

    Args:
        user_query: The user's original question.
        model: Model passed through to ``generate``.
        tokenizer: Tokenizer passed through to ``generate``.
        tools: Tool registry used for the system prompt and for execution.
        max_iterations: Maximum number of generation turns.
        max_new_tokens: Token budget for each generation.
        temperature: Sampling temperature for each generation.

    Returns:
        The first response without a function call. If every turn requested
        a tool, the last response with its function-call blocks stripped.
        An empty string when ``max_iterations`` is zero or negative.
    """
    system_prompt = build_system_prompt_with_tools(
        tools=tools, 
    )
    current_query = user_query
    # Pre-set so the final return is well-defined even if the loop body
    # never runs (max_iterations <= 0).
    response = ""
    for _ in range(max_iterations):
        prompt = lm_template(
            text=current_query,
            system_prompt=system_prompt
        )
        response = generate(
            prompt=prompt,
            tokenizer=tokenizer,
            model=model,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        function_call = parse_function_call(
            response=response
        )
        if not function_call:
            return response

        tool_name = function_call.get("name")
        arguments = function_call.get("arguments", {})
        tool_result = execute_tool(
            tool_name=tool_name, 
            arguments=arguments, 
            tools=tools, 
        )
        # The next turn sees the tool result together with the original
        # question rather than the intermediate function-call text.
        current_query = result2prompt(
            user_query=user_query, 
            tool_name=tool_name, 
            arguments=arguments, 
            tool_result=tool_result, 
        )

    # Iteration budget exhausted: never surface raw function-call markup.
    return filter_function_call(response)

In [10]:
# Load the model and tokenizer once; the demo loop further down reuses both.
model, tokenizer = create_model(lm_model_name="gemma-3-4b-it")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.47s/it]


In [11]:
# Demo queries: plain arithmetic, a question no tool can answer, a compound
# expression, and a word problem that has to be turned into arithmetic.
test_cases = [
    "What is 123 multiplied by 456?",
    "What's the weather like today?",
    "Calculate (987 + 654) * 321 / 2",
    "If I have 15 apples and buy 23 more, then give away 8, how many do I have?",
]

In [12]:
# Run every demo query through the agent and print the exchange.
print("Start Demo!\n")

for i, query in enumerate(test_cases, start=1):
    print(f"Test Case ({i}) {'=' * 50}")
    print(f"user input: {query}")

    # Default iteration, token and temperature settings are used.
    response = run_agent(query, model, tokenizer, TOOL_DICT)
    print(f"Model response: {response}\n{'=' * 64}\n")

print("\nDemo Completed!")

Start Demo!

user input: What is 123 multiplied by 456?
Model response: 123 multiplied by 456 is 56088.


user input: What's the weather like today?
Model response: I do not have access to real-time weather information.

user input: Calculate (987 + 654) * 321 / 2
Model response: The result of the calculation is 263380.5.

user input: If I have 15 apples and buy 23 more, then give away 8, how many do I have?
Model response: If I have 15 apples and buy 23 more, then give away 8, how many do I have?

The result of calculator({'expression': '15 + 23 - 8'}) is: 30

You would have 30 apples.


Demo Completed!
