In [71]:
from pprint import pprint
from typing import Literal, List, Dict

from gen_ai_hub.proxy.native.openai import chat, completions
from pipeline.generate_tool_schema import generate_json_schema

In [72]:
def cosine(x: int) -> float:
    """
    Calculate the cosine of an integer 
    
    :param x: The input integer
    :return: A float representing the cosine of x
    """
    import math
    return math.cos(x)

tools = []

tools.append(generate_json_schema(cosine))
pprint(tools[0])

{'function': {'description': 'Calculate the cosine of an integer ',
              'name': 'cosine',
              'parameters': {'properties': {'x': {'description': 'The input '
                                                                 'integer',
                                                  'title': 'X',
                                                  'type': 'integer'}},
                             'required': ['x'],
                             'title': 'cosine',
                             'type': 'object'}},
 'type': 'function'}


In [73]:
messages = []

system_message = {"role": "system", "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user."}
user_message =  {"role": "user", "content": "Can you calculate cos(3) for me?"}

messages.append(system_message)
messages.append(user_message)

response = chat.completions.create( 
        messages = messages,
        model_name = "gpt-4",
        tools=tools
)

response_message = response.choices[0].message

In [74]:
#print(response)
pprint(response_message.to_dict())

{'content': None,
 'role': 'assistant',
 'tool_calls': [{'function': {'arguments': '{"x":3}', 'name': 'cosine'},
                 'id': 'call_8P56uYMn0GcbJL3OHHZepu1A',
                 'type': 'function'}]}


In [75]:
def handle_response(response):
    import json
    
    response_message = response.choices[0].message
    
    tool_call_block = response_message.tool_calls[0]
    if tool_call_block:
        messages.append({"role": "assistant", "tool_calls": response_message.tool_calls})
        tool_call_id = tool_call_block.id
        tool_call_function_name = tool_call_block.function.name
        tool_call_function_arguments = json.loads(tool_call_block.function.arguments)
        
        
        if tool_call_function_name == 'cosine':
            x = tool_call_function_arguments.get('x')
            tool_call_result = cosine(x),
            function_call_result_message = {
                "role" : "tool", 
                "content": json.dumps({
                    "x": 3, 
                    "cos(x)": tool_call_result
                }), 
            "tool_call_id": tool_call_id
            }
            messages.append(function_call_result_message)  

In [76]:
handle_response(response)

print(messages)

[{'role': 'system', 'content': 'You are a helpful customer support assistant. Use the supplied tools to assist the user.'}, {'role': 'user', 'content': 'Can you calculate cos(3) for me?'}, {'role': 'assistant', 'tool_calls': [ChatCompletionMessageToolCall(id='call_8P56uYMn0GcbJL3OHHZepu1A', function=Function(arguments='{"x":3}', name='cosine'), type='function')]}, {'role': 'tool', 'content': '{"x": 3, "cos(x)": [-0.9899924966004454]}', 'tool_call_id': 'call_8P56uYMn0GcbJL3OHHZepu1A'}]


In [77]:
response = chat.completions.create( 
        messages = messages,
        model_name = "gpt-4",
        tools=tools
)

In [78]:
print(response)

ChatCompletion(id='chatcmpl-AaRqLZsIvp4TDz5hTLnLyh53K4pJY', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='The value of cos(3) is approximately -0.9899924966004454.', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), content_filter_results={'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}})], created=1733250201, model='gpt-4-turbo-2024-04-09', object='chat.completion', service_tier=None, system_fingerprint='fp_5603ee5e2e', usage=CompletionUsage(completion_tokens=20, prompt_tokens=116, total_tokens=136, completion_tokens_details=None, prompt_tokens_details=None), prompt_filter_results=[{'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'sev