# 基于 Chat Completions API 实现外部函数调用

In [45]:
import json
import os
import token

import requests
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
from termcolor import colored

# Default model name; used as the default `model` argument of chat_completion_request.
GPT_MODEL = "gpt-3.5-turbo"

## 定义工具函数

In [46]:
@retry(
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(3),
    # Only transport-level failures (connection errors, timeouts) are worth
    # retrying; configuration errors such as a missing API key are not.
    retry=retry_if_exception_type(requests.RequestException),
    # After the last attempt, raise the real exception instead of RetryError.
    reraise=True,
)
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL, timeout=60):
    """Send one request to the OpenAI Chat Completions endpoint.

    Args:
        messages: Conversation so far, a list of message dicts
            (``{"role": ..., "content": ...}``).
        functions: Optional list of function schemas the model may call;
            omitted from the payload when ``None``.
        function_call: Optional ``function_call`` control value (e.g. ``"none"``);
            omitted from the payload when ``None``.
        model: Model name; defaults to ``GPT_MODEL``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The raw ``requests.Response``. HTTP error statuses are not raised here,
        so callers must inspect the response body themselves.

    Raises:
        RuntimeError: if the ``OPENAI_API_KEY`` environment variable is unset or empty.
        requests.RequestException: if the request still fails after 3 attempts.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Fail with a clear message instead of a TypeError from str + None.
        raise RuntimeError("OPENAI_API_KEY environment variable is not set")

    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }

    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data["functions"] = functions
    if function_call is not None:
        json_data["function_call"] = function_call

    try:
        return requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
            timeout=timeout,
        )
    except requests.RequestException as e:
        print("Unable to generate chat completion request")
        print(f"Exception: {e}")
        # Re-raise so @retry can try again. Returning the exception object (as
        # before) disabled retries and handed callers an object with no .json().
        raise
        

In [74]:
# Print a chat transcript with one terminal colour per message role.
def pretty_print_conversation(messages):
    """Print every message in ``messages`` coloured by its role.

    Supported roles: system, user, assistant and function. An assistant
    message that carries a truthy ``function_call`` is printed as that call;
    otherwise its ``content`` is printed. Messages with any other role are
    silently skipped.
    """
    palette = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta"
    }

    for msg in messages:
        role = msg["role"]

        if role in ("system", "user"):
            line = f"{role}: {msg['content']}\n"
        elif role == "assistant":
            call = msg.get("function_call")
            if call:
                line = f"assistant[function_call]: {call}\n"
            else:
                line = f"assistant[content]: {msg['content']}\n"
        elif role == "function":
            line = f"function: ({msg['name']}): {msg['content']}\n"
        else:
            # Unknown roles are ignored.
            continue

        print(colored(line, palette[role]))

In [75]:
# Function schemas offered to the model via the `functions` request field.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get current weather conditions",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {  # where to look up the weather
                    "type": "string",
                    "description": "Location of the weather conditions, the city and country, such as 'New York'",
                },
                "format": {  # temperature unit
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                }
            },
            "required": ["location", "format"]
        }
    },
    {
        "name": "get_n_day_weather_forecast",
        "description": "Get an N-day weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "Location of the weather forecast, such as 'New York'",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
                "num_days": {
                    "type": "integer",
                    "description": "The number of days to forecast.",
                }
            },
            # "required" belongs to the JSON Schema inside "parameters"; placed
            # beside "parameters" (as before) it is ignored, letting the model
            # call this function without format/num_days.
            "required": ["location", "format", "num_days"]
        }
    }
]

In [49]:
# Fresh conversation: a system rule (in Chinese) not to guess function
# arguments and to ask for clarification, then a deliberately vague user
# question ("How is the weather today?") with no location given.
messages = [
    {
        "role": "system",
        "content": "不要对函数中要插入的值做出假设。如果用户的请求含糊不清，请求进一步澄清"
    },
    {
        "role": "user",
        "content": "今天天气如何？"
    },
]

chat_response = chat_completion_request(messages, functions=functions)

# Keep the assistant's reply in the transcript, then show the whole conversation.
message = chat_response.json()["choices"][0]["message"]
messages.append(message)
pretty_print_conversation(messages)

[31msystem: 不要对函数中要插入的值做出假设。如果用户的请求含糊不清，请求进一步澄清
[0m
[32muser: 今天天气如何？
[0m
[34massistant[content]: 请问您需要查询哪个城市的天气？
[0m


In [50]:
# Answer the assistant's clarifying question with a location
# ("I'm in Shanghai, China") and send the extended conversation again.
messages.append({"role": "user", "content": "我在中国的上海市"})

chat_response = chat_completion_request(messages, functions=functions)

message = chat_response.json()["choices"][0]["message"]
messages.append(message)
pretty_print_conversation(messages)

[31msystem: 不要对函数中要插入的值做出假设。如果用户的请求含糊不清，请求进一步澄清
[0m
[32muser: 今天天气如何？
[0m
[34massistant[content]: 请问您需要查询哪个城市的天气？
[0m
[32muser: 我在中国的上海市
[0m
[34massistant[function_call]: {'name': 'get_current_weather', 'arguments': '{"location":"Shanghai","format":"celsius"}'}
[0m


In [51]:
# New conversation with the same system rule, but this time the user names
# both the city (Suzhou) and the unit (Celsius) up front.
messages = [
    {
        "role": "system",
        "content": "不要对函数中要插入的值做出假设。如果用户的请求含糊不清，请求进一步澄清"
    },
    {
        "role": "user",
        "content": "请向我提供苏州今天的天气(使用Celsius)？"
    },
]

chat_response = chat_completion_request(messages, functions=functions)

message = chat_response.json()["choices"][0]["message"]
messages.append(message)
pretty_print_conversation(messages)

[31msystem: 不要对函数中要插入的值做出假设。如果用户的请求含糊不清，请求进一步澄清
[0m
[32muser: 请向我提供苏州今天的天气(使用Celsius)？
[0m
[34massistant[function_call]: {'name': 'get_current_weather', 'arguments': '{"location":"Suzhou","format":"celsius"}'}
[0m


In [52]:
# New conversation with an English system prompt that forbids guessing
# function arguments and asks for clarification on ambiguous requests.
messages = [
    {
        "role": "system",
        "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."
    },
    # The user asks about the weather in Shanghai, China over an unspecified
    # ("x") number of days, so the forecast length is ambiguous.
    {
        "role": "user",
        "content": "what is the weather going to be like in Shanghai, China over the next x days"
    },
]

chat_response = chat_completion_request(messages, functions=functions)

# Take the first choice's message, add it to the transcript and print everything.
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
pretty_print_conversation(messages)

[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.
[0m
[32muser: what is the weather going to be like in Shanghai, China over the next x days
[0m
[34massistant[function_call]: {'name': 'get_n_day_weather_forecast', 'arguments': '{"location":"Shanghai, China"}'}
[0m


In [55]:
# Another fresh conversation: same system prompt, and a fully specified request
# for Toronto's current weather in Celsius.
messages = [
    {
        "role": "system",
        "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."
    },
    {
        "role": "user",
        "content": "Give me the current weather (use Celcius) for Toronto, Canada."
    },
]

# Passing function_call="none" presumably stops the model from calling a
# function even though `functions` are offered; the recorded reply is plain text.
chat_response = chat_completion_request(
    messages, functions=functions, function_call="none"
)

assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
pretty_print_conversation(messages)

[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.
[0m
[32muser: Give me the current weather (use Celcius) for Toronto, Canada.
[0m
[34massistant[content]: Sure! Let me get the current weather conditions in Toronto, Canada using Celsius as the temperature unit.
[0m


## 使用GPT模型执行生成的函数

In [57]:
import sqlite3
# Open the Chinook sample database read-only. ask_db() later executes
# model-generated SQL on this connection, so it must not be able to modify the
# file. mode=ro also errors out if the file is missing, instead of silently
# creating an empty database as a plain connect() would.
conn = sqlite3.connect("file:data/chinook.db?mode=ro", uri=True)
print("DB Opened.")

DB Opened.


In [62]:
def get_table_names(conn):
    """Return the name of every table listed in ``sqlite_master``.

    Internal tables such as ``sqlite_sequence`` are included as well.
    """
    cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    return [row[0] for row in cursor.fetchall()]

def get_column_names(conn, table_name):
    """Return the column names of ``table_name``, in the order PRAGMA reports them.

    Uses ``PRAGMA table_info``, whose result rows carry the column name at
    index 1. Returns an empty list when no such table exists.
    """
    # The name is spliced into the SQL as a single-quoted string literal, so
    # double any embedded quotes; otherwise a name like "it's" breaks the SQL.
    quoted_name = table_name.replace("'", "''")
    rows = conn.execute(f"PRAGMA table_info('{quoted_name}');").fetchall()
    return [row[1] for row in rows]

def get_database_info(conn):
    """Describe the schema of every table in the database.

    Returns a list with one ``{"table_name": ..., "column_names": [...]}``
    dict per table, in the order produced by ``get_table_names``.
    """
    return [
        {"table_name": name, "column_names": get_column_names(conn, name)}
        for name in get_table_names(conn)
    ]

In [63]:
# Show the schema: one {"table_name": ..., "column_names": [...]} dict per table.
get_database_info(conn)

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_i

In [65]:
# Flatten the schema into compact text for the model prompt. Each table is
# rendered as a "Table: <name>" line followed by a "Columns: a,b,c" line.
db_dict = get_database_info(conn)
db_string = "\n".join(
    f"Table: {table['table_name']}\nColumns: {','.join(table['column_names'])}"
    for table in db_dict
)

In [66]:
# Display the schema text that gets embedded in the ask_db function description.
db_string

'Table: albums\n Colunms: AlbumId,Title,ArtistId\nTable: sqlite_sequence\n Colunms: name,seq\nTable: artists\n Colunms: ArtistId,Name\nTable: customers\n Colunms: CustomerId,FirstName,LastName,Company,Address,City,State,Country,PostalCode,Phone,Fax,Email,SupportRepId\nTable: employees\n Colunms: EmployeeId,LastName,FirstName,Title,ReportsTo,BirthDate,HireDate,Address,City,State,Country,PostalCode,Phone,Fax,Email\nTable: genres\n Colunms: GenreId,Name\nTable: invoices\n Colunms: InvoiceId,CustomerId,InvoiceDate,BillingAddress,BillingCity,BillingState,BillingCountry,BillingPostalCode,Total\nTable: invoice_items\n Colunms: InvoiceLineId,InvoiceId,TrackId,UnitPrice,Quantity\nTable: media_types\n Colunms: MediaTypeId,Name\nTable: playlists\n Colunms: PlaylistId,Name\nTable: playlist_track\n Colunms: PlaylistId,TrackId\nTable: tracks\n Colunms: TrackId,Name,AlbumId,MediaTypeId,GenreId,Composer,Milliseconds,Bytes,UnitPrice\nTable: sqlite_stat1\n Colunms: tbl,idx,stat'

In [67]:
# Display the per-table schema list that db_string was built from.
db_dict

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_i

In [85]:
# Function schema for the SQL demo: the model answers questions by emitting a
# single SQLite query in the `query` argument. The schema text (db_string) is
# embedded so the model knows the available tables and columns. The strings
# are built explicitly so no source-code indentation leaks into the prompt.
functions = [
    {
        "name": "ask_db",
        "description": (
            "Use this function to answer user questions about music. "
            "Output should be a fully formed SQL query."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "SQLite query extracting info to answer the user's question.\n"
                        "SQL should be written using this database schema:\n"
                        f"{db_string}\n"
                        "The query should be returned in plain text, not in JSON."
                    ),
                }
            },
            "required": ["query"],
        },
    }
]

In [86]:
def ask_db(conn, query):
    """Run ``query`` on ``conn`` and return the result as text.

    The query is echoed to stdout first. On success the return value is
    ``str()`` of the full ``fetchall()`` list. Any exception is turned into a
    ``"query failed with error: ..."`` string, so the caller (and the model)
    always receives text rather than an exception.
    """
    print(f"query: {query}")
    try:
        return str(conn.execute(query).fetchall())
    except Exception as exc:
        return f"query failed with error: {exc}"

def execute_function_call(message):
    """Execute the function requested by an assistant ``function_call`` message.

    Only ``ask_db`` is supported. It is run against the module-level ``conn``
    with the ``query`` taken from the JSON-encoded ``arguments``.

    Returns:
        The function result as a string. Unknown function names and malformed
        arguments (invalid JSON, a non-object payload, or a missing ``query``)
        yield an ``"Error: ..."`` string, so it can be passed back to the model
        instead of crashing the conversation loop.
    """
    function_name = message["function_call"]["name"]
    if function_name != "ask_db":
        return f"Error: {function_name} does not exist"

    try:
        query = json.loads(message["function_call"]["arguments"])["query"]
    except (json.JSONDecodeError, KeyError, TypeError) as exc:
        # The arguments are model-generated and not guaranteed to be valid JSON
        # with a "query" key.
        return f"Error: invalid arguments for {function_name}: {exc!r}"

    return ask_db(conn, query)



In [87]:
# Start the SQL demo conversation: a system instruction to answer via SQL on
# the Chinook database, followed by the user's first question.
messages = [
    {
        "role": "system",
        "content": "Answer user questions by generating SQL queries against the Chinook Music Database."
    },
    {
        "role": "user",
        "content": "Hi, who are the top 5 artists by number of tracks?"
    },
]

chat_response = chat_completion_request(messages, functions)
message = chat_response.json()["choices"][0]["message"]
messages.append(message)

# If the model asked for a function, run it and append the result as a
# "function" role message so it becomes part of the transcript.
call = message.get("function_call")
if call:
    results = execute_function_call(message)
    messages.append({"role": "function", "name": message["function_call"]["name"], "content": results})

pretty_print_conversation(messages)

query: SELECT artists.Name, COUNT(tracks.TrackId) AS TrackCount
FROM artists
JOIN albums ON artists.ArtistId = albums.ArtistId
JOIN tracks ON albums.AlbumId = tracks.AlbumId
GROUP BY artists.ArtistId
ORDER BY TrackCount DESC
LIMIT 5;
[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.
[0m
[32muser: Hi, who are the top 5 artists by number of tracks?
[0m
[34massistant[function_call]: {'name': 'ask_db', 'arguments': '{"query":"SELECT artists.Name, COUNT(tracks.TrackId) AS TrackCount\\nFROM artists\\nJOIN albums ON artists.ArtistId = albums.ArtistId\\nJOIN tracks ON albums.AlbumId = tracks.AlbumId\\nGROUP BY artists.ArtistId\\nORDER BY TrackCount DESC\\nLIMIT 5;"}'}
[0m
[35mfunction: (ask_db): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]
[0m


In [90]:
# Run an essentially equivalent hand-written query directly through ask_db
# (bypassing the model) to cross-check the model-generated query's result.
query = "SELECT artists.Name AS Artist, COUNT(tracks.TrackId) AS TrackCount FROM artists\n JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY TrackCount DESC LIMIT 5;"
ask_db(conn, query)

query: SELECT artists.Name AS Artist, COUNT(tracks.TrackId) AS TrackCount FROM artists
 JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY TrackCount DESC LIMIT 5;


"[('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]"

In [91]:
# Continue the same SQL conversation with a follow-up question.
messages.append({"role": "user", "content": "What is the name of the album with the most tracks?"})

chat_response = chat_completion_request(messages, functions)
message = chat_response.json()["choices"][0]["message"]
messages.append(message)

# Run the requested function, if any, and record its output in the transcript.
call = message.get("function_call")
if call:
    results = execute_function_call(message)
    messages.append({"role": "function", "content": results, "name": message["function_call"]["name"]})

pretty_print_conversation(messages)

query: SELECT albums.Title, COUNT(tracks.TrackId) AS TrackCount
FROM albums
JOIN tracks ON albums.AlbumId = tracks.AlbumId
GROUP BY albums.AlbumId
ORDER BY TrackCount DESC
LIMIT 1;
[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.
[0m
[32muser: Hi, who are the top 5 artists by number of tracks?
[0m
[34massistant[function_call]: {'name': 'ask_db', 'arguments': '{"query":"SELECT artists.Name, COUNT(tracks.TrackId) AS TrackCount\\nFROM artists\\nJOIN albums ON artists.ArtistId = albums.ArtistId\\nJOIN tracks ON albums.AlbumId = tracks.AlbumId\\nGROUP BY artists.ArtistId\\nORDER BY TrackCount DESC\\nLIMIT 5;"}'}
[0m
[35mfunction: (ask_db): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]
[0m
[32muser: What is the name of the album with the most tracks?
[0m
[34massistant[function_call]: {'name': 'ask_db', 'arguments': '{"query":"SELECT albums.Title, COUNT(tracks.TrackId) AS TrackCount