# How to force models to call a tool

In [8]:
import os
import getpass
from langchain_core.tools import tool
from langchain.chat_models import init_chat_model

In [3]:
# Ask for the Groq key interactively only when GROQ_API_KEY is unset or empty.
# getpass hides the typed input, so the secret never ends up in the notebook or its outputs.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")

Enter API key ········


In [5]:
# The @tool decorator (langchain_core) turns this function into a tool the model can call.
# Its name ("add"), typed parameters (a, b) and, per the langchain_core docs, its docstring
# (used as the tool description) make up the schema sent to the model.
# Editing any of them changes what the model sees.
@tool
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


# The @tool decorator (langchain_core) turns this function into a tool the model can call.
# Its name ("multiply"), typed parameters (a, b) and, per the langchain_core docs, its
# docstring (used as the tool description) make up the schema sent to the model.
# Editing any of them changes what the model sees.
@tool
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b

In [6]:
# Every tool offered to the model; both bind_tools calls below receive this same list.
tools = [
    add,
    multiply,
]

In [9]:
# Base chat model served by Groq. It is reused below with two different tool bindings.
# NOTE(review): confirm Groq still serves this model id before re-running the notebook.
llm = init_chat_model(
    model="llama3-8b-8192",
    model_provider="groq",
)

In [10]:
# Passing a single tool's name as tool_choice forces the model to respond with a call
# to that tool ("multiply"), whatever the question is.
llm_forced_to_multiply = llm.bind_tools(
    tools,
    tool_choice="multiply",
)

In [11]:
# Deliberately ask an addition question: the forced binding still answers with a
# `multiply` tool call rather than calling `add`.
llm_forced_to_multiply.invoke(
    "what is 2 + 4"
)

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_a5ca', 'function': {'arguments': '{"a":2,"b":4}', 'name': 'multiply'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 72, 'prompt_tokens': 1892, 'total_tokens': 1964, 'completion_time': 0.06, 'prompt_time': 0.23908216, 'queue_time': -0.501590989, 'total_time': 0.29908216}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_dadc9d6142', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-3c027668-4521-4c84-9e60-42887c522a38-0', tool_calls=[{'name': 'multiply', 'args': {'a': 2, 'b': 4}, 'id': 'call_a5ca', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1892, 'output_tokens': 72, 'total_tokens': 1964})

In [13]:
# Instead of naming one tool, we can force the model to call at least one of the bound
# tools (leaving it free to choose which) by passing "any" to the tool_choice parameter.
# ("required" is the OpenAI-specific spelling of the same option.)
llm_forced_to_use_tool = llm.bind_tools(tools, tool_choice="any")
llm_forced_to_use_tool.invoke("what is 2 + 4?")

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_m83h', 'function': {'arguments': '{"a":2,"b":4}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 79, 'prompt_tokens': 936, 'total_tokens': 1015, 'completion_time': 0.065833333, 'prompt_time': 0.145963313, 'queue_time': 0.974110807, 'total_time': 0.211796646}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_179b0f92c9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-5e63684d-4e7e-4377-b767-395df99a7965-0', tool_calls=[{'name': 'add', 'args': {'a': 2, 'b': 4}, 'id': 'call_m83h', 'type': 'tool_call'}], usage_metadata={'input_tokens': 936, 'output_tokens': 79, 'total_tokens': 1015})