In [1]:
# !pip install langchain-openai
from langchain_core.tools import tool 
from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolMessage, AIMessageChunk
from typing import Callable
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, AzureChatOpenAI
# Load variables from a local .env file *before* any client is constructed, so
# that credentials supplied through the environment are visible to the client.
load_dotenv()

# When AzureChatOpenAI is used for token counting, the `model` parameter must
# name the underlying model. The value is parsed by tiktoken, so it has to be
# a model name that tiktoken recognises.
azure_llm = AzureChatOpenAI(azure_deployment='test',
                      api_version='2024-02-15-preview',
                      azure_endpoint='https://test.com',
                      model='gpt-4o'
                      )

# Integer addition, registered as a tool that the model may call.
# NOTE(review): the @tool decorator presumably uses the docstring below as the
# tool description shown to the model, so treat it as runtime text — confirm
# before rewording it.
@tool 
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b 

# Integer multiplication, registered as a tool that the model may call.
# NOTE(review): the @tool decorator presumably uses the docstring below as the
# tool description shown to the model, so treat it as runtime text — confirm
# before rewording it.
@tool 
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b 

# Tools offered to the model.
tools = [add, multiply]

# Name -> tool lookup used to dispatch the model's tool calls. It is derived
# from `tools` so the two registries cannot drift apart. Keys are lower-cased
# because the dispatcher lower-cases the requested name before looking it up.
tool_dict = {registered_tool.name.lower(): registered_tool for registered_tool in tools}

def tool_call_from_ai_message(tool_dict: dict[str, Callable], tool_chunks: AIMessageChunk, messages: list) -> list:
    """Run every tool call requested by the model and record the results.

    Args:
        tool_dict: Lookup from lower-cased tool name to an object exposing
            ``.invoke(args)``.
        tool_chunks: Aggregated AI message whose ``tool_calls`` list the
            requested calls (each with ``name``, ``args`` and ``id``).
        messages: Conversation so far. It is not mutated.

    Returns:
        A shallow copy of ``messages`` with one ``ToolMessage`` appended per
        requested tool call, in request order.
    """
    updated_messages = messages.copy()
    for tool_info in tool_chunks.tool_calls:
        print(f'CALL: {tool_info["name"]}, {tool_info["args"]}')
        selected_tool = tool_dict.get(tool_info['name'].lower())
        if selected_tool is None:
            # The model asked for a tool that is not registered. Answer the
            # call with an error result instead of raising KeyError, so the
            # call id still gets a reply and the model can see what went wrong.
            updated_messages.append(
                ToolMessage(
                    f'Unknown tool: {tool_info["name"]}',
                    tool_call_id=tool_info['id'],
                    status='error',
                )
            )
            continue
        tool_output = selected_tool.invoke(tool_info['args'])
        updated_messages.append(ToolMessage(tool_output, tool_call_id=tool_info['id']))
    return updated_messages

In [2]:
# Chat model for the tool-calling demo: temperature 0.0, and each completion
# is capped at 100 tokens.
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0.0,
                 max_tokens=100)
# Same model with the `tools` list bound to it.
llm_with_tools = llm.bind_tools(tools)

trimmer = trim_messages(
    max_tokens=10_000,
    strategy='last',
    # Pass the chat model itself: `token_counter` is documented to accept a
    # language model (or a counting function), not a tool-bound runnable.
    token_counter=llm,
    include_system=True,
    # allow_partial=True would cut inside a message's content so that the
    # result fits max_tokens (counting back from the last message).
    allow_partial=False,
)



In [3]:
query = 'What is 3 * 12? Also, what is 11 + 49?'

# query = '3*(12 +30)は？'
messages = [
    HumanMessage(query)
]

is_answer = False

# AI が関数呼び出しではなく、回答を生成したら終わり
while not is_answer:
    aggregate = None 
    is_finish = False
    finish_reason = None 
    async for chunk in llm_with_tools.astream(messages, stream_usage=True):
        
        aggregate = chunk if aggregate is None else aggregate + chunk 
        if is_finish:
            messages.append(aggregate)
            
            if finish_reason == 'tool_calls': 
                print('Tool Calling Start!!')
                messages = tool_call_from_ai_message(tool_dict, aggregate, messages)
            else:
                print()
                print(f'finish_reason: {finish_reason}')
                is_answer = True 
            break 
        
        if ('finish_reason' in chunk.response_metadata):
            finish_reason = chunk.response_metadata['finish_reason'] 
            is_finish = True 
            
        
        if len(chunk.content) == 0 and len(chunk.tool_call_chunks) == 0:
            continue 
        
        if len(chunk.tool_call_chunks) > 0:
            continue 
        
        # AI回答文
        print(chunk.content, end='')
    
    if is_answer:
        break 

Tool Calling Start!!
CALL: multiply, {'a': 3, 'b': 12}
CALL: add, {'a': 11, 'b': 49}
The result of \(3 \times 12\) is 36, and the result of \(11 + 49\) is 60.
finish_reason: stop


In [4]:
# messages[1] is the first model reply appended by the loop: the aggregated
# chunk that carries the tool calls.
ai_message = messages[1]

In [8]:
# Dump the stored AI message as indented JSON to inspect how tool calls are represented.
print(ai_message.json(indent=2))

{
  "content": "",
  "additional_kwargs": {
    "tool_calls": [
      {
        "index": 0,
        "id": "call_to2SgrJZunK8sfdhwqPWjqtS",
        "function": {
          "arguments": "{\"a\": 3, \"b\": 12}",
          "name": "multiply"
        },
        "type": "function"
      },
      {
        "index": 1,
        "id": "call_QA7cXz6hjb0EIyNCa6HIBNkF",
        "function": {
          "arguments": "{\"a\": 11, \"b\": 49}",
          "name": "add"
        },
        "type": "function"
      }
    ]
  },
  "response_metadata": {
    "finish_reason": "tool_calls",
    "model_name": "gpt-4o-mini-2024-07-18",
    "system_fingerprint": "fp_483d39d857"
  },
  "type": "AIMessageChunk",
  "name": null,
  "id": "run-3b2d44f4-49f7-4402-82b3-d7fcf0c43a33",
  "example": false,
  "tool_calls": [
    {
      "name": "multiply",
      "args": {
        "a": 3,
        "b": 12
      },
      "id": "call_to2SgrJZunK8sfdhwqPWjqtS",
      "type": "tool_call"
    },
    {
      "name": "add",
      "ar

In [9]:
# Dump the message that follows the AI reply (the first tool result) as indented JSON.
print(messages[2].json(indent=2))

{
  "content": "36",
  "additional_kwargs": {},
  "response_metadata": {},
  "type": "tool",
  "name": null,
  "id": null,
  "tool_call_id": "call_to2SgrJZunK8sfdhwqPWjqtS",
  "artifact": null,
  "status": "success"
}


In [18]:

# Same trimmer configuration with a 100-token budget, applied to a short sample history.
trimmer = trim_messages(
    max_tokens=100,
    strategy='last',
    # Pass the chat model itself: `token_counter` is documented to accept a
    # language model (or a counting function), not a tool-bound runnable.
    token_counter=llm,
    include_system=True,
    # allow_partial=True would cut inside a message's content so that the
    # result fits max_tokens (counting back from the last message).
    allow_partial=False,
)

messages = [
    SystemMessage("Youa are AI Assistant"),
    HumanMessage("What is human?"),
    AIMessage("Human is human"),
    HumanMessage("What is 3*12?")
]

trimmer.invoke(messages)



[SystemMessage(content='Youa are AI Assistant'),
 HumanMessage(content='What is human?'),
 AIMessage(content='Human is human'),
 HumanMessage(content='What is 3*12?')]

In [173]:
# Inspect the metaclass of the message classes to see which pydantic implementation backs them.
type(HumanMessage), type(AIMessage)

(pydantic.v1.main.ModelMetaclass, pydantic.v1.main.ModelMetaclass)

In [10]:
# str() of a dict yields its Python repr (single-quoted keys), which is not valid JSON.
a = {'a': 12, 'b': 32}
str(a)

"{'a': 12, 'b': 32}"

In [11]:
# json.dumps serializes the dict `a` into a double-quoted JSON string.
import json 
json.dumps(a)

'{"a": 12, "b": 32}'