In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="7"
import sys
sys.path.append("ToolAlpaca")
from tqdm import tqdm
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import HuggingFacePipeline
from pprint import pprint
from agent.custom_parser import CustomMRKLOutputParser2, CustomMRKLOutputParser, CustomMRKLOutputParser3
from copy import deepcopy
from utils import load_openapi_spec, escape
from agent.tools import Tool, GetDetailsTool, tool_projection
from agent.tools import CustomInvalidTool
from transformers import StoppingCriteria
import torch
from langchain.schema import AgentAction, AgentFinish
from collections import Counter
import random
random.seed(0)

In [3]:
data = json.load(open('data/train_data_val_part', 'r'))
generations = json.load(open('data/test_generations_val_model_trained_on_half_train.json', 'r'))
assert len(data) == len(generations)

In [7]:
parser = CustomMRKLOutputParser2()
parser.parse(generations[0][0][1])

AgentAction(tool='getSubdomains', tool_input='{"domain": "example.com"}', log=' I need to use the getSubdomains tool to retrieve a list of subdomains for the domain example.com.\nASSISTANT Action: getSubdomains\nASSISTANT Action Input: {"domain": "example.com"}')

In [8]:
parser.parse(data[0][0][1])

AgentAction(tool='getSubdomains', tool_input='{"domain": "example.com"}\nASSISTANT Observation:', log='I need to use the getSubdomains tool to find the subdomains for example.com.\nASSISTANT Action: getSubdomains\nASSISTANT Action Input: {"domain": "example.com"}\nASSISTANT Observation: ')

In [9]:
test_response = data[2][0][1]
parser.parse(test_response)

AgentAction(tool='getTechnologies', tool_input='{"domain": "amazon.com"}\nASSISTANT Observation:', log='I need to use the getTechnologies tool to retrieve the technologies used by amazon.com.\nASSISTANT Action: getTechnologies\nASSISTANT Action Input: {"domain": "amazon.com"}\nASSISTANT Observation: ')

Смимикрируем генерацию ответа от тула

In [10]:
original_full_train = json.load(open('data/train_data.json', 'r'))

In [11]:
def get_all_tools(item, enable_getDetails=False, server_url="http://127.0.0.1:5679"):
    openapi_spec = load_openapi_spec(item["Documentation"], replace_refs=True)
    components_descriptions = escape(item["Function_Description"]["components"])

    tools = [GetDetailsTool()] if not enable_getDetails else []
    for ext_tool in item.get("external_tools", []):
        tools.append(tool_projection[ext_tool]())

    for idx, func_name in enumerate(item["Function_Projection"]):
        description = escape(item["Function_Description"][func_name])
        if idx == len(item["Function_Projection"]) - 1:
            description += components_descriptions
        path, method = item["Function_Projection"][func_name]
        tools.append(Tool(
            base_url=server_url + "/" + item["Name"] if server_url else None,
            func_name=func_name,
            openapi_spec=openapi_spec,
            path=path,
            method=method,
            description=description,
            retrieval_available="retrieval" in item.get("external_tools", [])
        ))
    tools = {tool.name: tool for tool in tools}
    return tools

In [12]:
get_all_tools(original_full_train[0])

{'getDetails': GetDetailsTool(name='getDetails', description='If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {{"Question": "The question to prompt user to provide sufficient information."}}\nOutput: User\'s response.', args_schema=None, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x7fd7c50a2530>, chat_history=[], func=<function GetDetailsTool.__init__.<locals>.func at 0x7fd7bd4053f0>, coroutine=None),
 'sendHttpRequest': Tool(name='sendHttpRequest', description='Send an HTTP request with the specified method, headers, and dat

In [13]:
query_tool_mapping = dict()
for item in original_full_train:
    tools = get_all_tools(item)
    for query in item['Instructions']:
        query_tool_mapping[query] = tools

Проитерируемся по всем генерациям и попробуем повызывать тулы

In [14]:
tools

{'getDetails': GetDetailsTool(name='getDetails', description='If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {{"Question": "The question to prompt user to provide sufficient information."}}\nOutput: User\'s response.', args_schema=None, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x7fd7c50a2530>, chat_history=[], func=<function GetDetailsTool.__init__.<locals>.func at 0x7fd7c0d4cdc0>, coroutine=None),
 'getVerse': Tool(name='getVerse', description='Retrieve the text of a specific verse from the Bible.\nParameters: {{"book": "

In [15]:
def parse_and_call_tool(prediction, tools):
    try:
        parsed = parser.parse(prediction)
        if isinstance(parsed, AgentFinish):
            return parsed.return_values['output']
        if parsed.tool in tools:
            tool = tools[parsed.tool]
            res = tool._run(parsed.tool_input)
        else:
            tool = CustomInvalidTool()
            res = tool._run(parsed.tool_input, tools)
        return res
    except Exception as e:
        return str(e)

def check_error(response):
    if 'Parameter type error' in response or 'Missing required parameters:' in response or 'Could not parse LLM output' in response:
        return True
    return False

def parse_query(meta_prompt):
    return meta_prompt.split('Begin!\n\nUSER: ')[1].split('\nASSISTANT Thought:')[0]

max_len = len(generations)
errors = []
for idx, (item_gt, item_pred) in enumerate(tqdm(zip(data[:max_len], generations))):
    query = parse_query(item_gt[0][0])
    tools = query_tool_mapping[query]

    flag = True
    for step_idx in range(len(item_gt[0])):
        if item_gt[1][step_idx]:
            res = parse_and_call_tool(item_gt[0][step_idx], tools)
            if check_error(res):
                flag = False

    if flag:
        for step_idx in range(len(item_gt[0])):
            if item_gt[1][step_idx]:
                res_pred = parse_and_call_tool(item_pred[0][step_idx], tools)
                res_gt = parse_and_call_tool(item_gt[0][step_idx], tools)
                if not check_error(res_gt) and check_error(res_pred):
                    errors.append((item_gt, item_pred, (step_idx, res_gt, res_pred)))

1676it [00:26, 64.00it/s] 


In [18]:
errors[0]

([['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions with the help of some tools.\nYou have access to the following tools:\n\ngetDetails: If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {"Question": "The question to prompt user to provide sufficient information."}\nOutput: User\'s response.\nsearch: Search for internet assets based on various parameters\nParameters: {"query": "Required. string. The search query. Can be a domain name, IP address, SSL certificate information, or other paramet

In [80]:
len(errors)

197

In [81]:
def construct_correction_example(error):
    error_idx = error[-1][0]
    gt_chain = error[0][0]
    pred_chain = error[1][0]
    gt_mask_flags = error[0][1]
    error_pred = error[-1][-1]
    new_chain = gt_chain[:error_idx] + [pred_chain[error_idx], error_pred] + gt_chain[error_idx:]
    new_mask_flags = gt_mask_flags[:error_idx] + [False, False] + gt_mask_flags[error_idx:]
    return [new_chain, new_mask_flags]

In [82]:
construct_correction_example(errors[105])

[['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions with the help of some tools.\nYou have access to the following tools:\n\ngetDetails: If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {"Question": "The question to prompt user to provide sufficient information."}\nOutput: User\'s response.\nrequestData: Request data from any API or data source and securely transmit it to a smart contract on the blockchain.\nParameters: {"endpoint": "string. The API endpoint to request data from.", "method":

In [83]:
correction_examples = []
for item in errors:
    correction_examples.append(construct_correction_example(item))

In [84]:
len(correction_examples)

197

In [153]:
# json.dump(correction_examples, open('../data/correction_examples_val_model_trained_on_half_train.json', 'w'))

Попробуем почистить трейн выборку от ошибок новым способом

In [92]:
full_train_preprocessed = json.load(open('../data/train_data_train_part', 'r'))

In [93]:
len(full_train_preprocessed)

2261

In [94]:
cleaned_train = []
for item in tqdm(full_train_preprocessed):
    query = parse_query(item[0][0])
    tools = query_tool_mapping[query]
    flag = True
    for step_idx in range(len(item[0])):
        if item[1][step_idx]:
            res = parse_and_call_tool(item[0][step_idx], tools)
            if check_error(res):
                print(res)
                flag = False
    if flag:
        cleaned_train.append(item)

  6%|███▋                                                        | 140/2261 [00:00<00:06, 330.52it/s]

Parameter type error: "fontType", expected one of ['Arial', 'Helvetica', 'Times New Roman', 'Courier New', 'Verdana'], but got "Courier". You need to change the input and try again.
Parameter type error: "color", expected one of ['black', 'white', 'red', 'green', 'blue'], but got "orange". You need to change the input and try again.
Parameter type error: "color", expected one of ['black', 'white', 'red', 'green', 'blue'], but got "orange". You need to change the input and try again.
Parameter type error: "fontType", expected one of ['Arial', 'Helvetica', 'Times New Roman', 'Courier New', 'Verdana'], but got "Roboto". You need to change the input and try again.
Parameter type error: "color", expected one of ['black', 'white', 'red', 'green', 'blue'], but got "purple". You need to change the input and try again.
Parameter type error: "color", expected one of ['black', 'white', 'red', 'green', 'blue'], but got "user-selected-color". You need to change the input and try again.
Parameter ty

 12%|███████▍                                                    | 281/2261 [00:00<00:06, 329.27it/s]

Parameter type error: "countryCode", expected string, but got array. You need to change the input and try again.
Parameter type error: "city", expected string, but got integer. You need to change the input and try again.
Parameter type error: "city", expected string, but got integer. You need to change the input and try again.
Parameter type error: "type", expected one of ['education', 'recreational', 'social', 'diy', 'charity', 'cooking', 'relaxation', 'music', 'busywork'], but got "romantic". You need to change the input and try again.


 16%|█████████▍                                                  | 355/2261 [00:01<00:06, 305.05it/s]

Parameter type error: "parkingId", expected integer, but got string. You need to change the input and try again.
Parameter type error: "length", expected integer, but got string. You need to change the input and try again.
Parameter type error: "limit", expected integer, but got string. You need to change the input and try again.


 19%|███████████▍                                                | 429/2261 [00:01<00:05, 325.68it/s]

Parameter type error: "routeNumber", expected integer, but got None. You need to change the input and try again.
Parameter type error: "id", expected integer, but got string. You need to change the input and try again.
Parameter type error: "id", expected integer, but got string. You need to change the input and try again.
Parameter type error: "limit", expected integer, but got None. You need to change the input and try again.
Parameter type error: "surfaceType", expected one of ['grass', 'asphalt', 'concrete', 'water', 'snow'], but got "well-lit". You need to change the input and try again.
Parameter type error: "facilities", expected string, but got array. You need to change the input and try again.
Parameter type error: "type", expected one of ['accident', 'construction', 'road closure', 'weather'], but got "accident, construction". You need to change the input and try again.


 22%|█████████████▎                                              | 500/2261 [00:01<00:05, 330.29it/s]

Parameter type error: "industry", expected one of ['construction', 'healthcare', 'education', 'transportation'], but got "pharmaceutical". You need to change the input and try again.
Parameter type error: "range", expected string, but got integer. You need to change the input and try again.
Parameter type error: "range", expected one of ['10', '20', '50', '100'], but got "10". You need to change the input and try again.


 27%|███████████████▉                                            | 602/2261 [00:01<00:05, 321.97it/s]

Parameter type error: "category", expected one of ['health', 'education', 'environment', 'transportation', 'finance', 'public safety'], but got "healthcare". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', 'in'], but got "eu". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', 'in'], but got "in,gb,de,fr,es". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', 'in'], but got "south america". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', 'in'], but got "south america". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', 'in'], but got "eu". You need to change the input and try again.
Parameter type error: "region", expected one of ['us', 'gb', 'ca', 'au', '

 32%|██████████████████▉                                         | 713/2261 [00:02<00:04, 330.62it/s]

Parameter type error: "limit", expected integer, but got string. You need to change the input and try again.
Parameter type error: "characterId", expected integer, but got string. You need to change the input and try again.


 36%|█████████████████████▌                                      | 814/2261 [00:02<00:05, 284.31it/s]

Parameter type error: "format", expected one of ['4cs', '6502acme', '6502kickass', '6502tasm', 'abap', 'actionscript', 'actionscript3', 'ada', 'aimms', 'algol68', 'apache', 'applescript', 'apt_sources', 'arm', 'asm', 'asp', 'asymptote', 'autoconf', 'autohotkey', 'autoit', 'avisynth', 'awk', 'bascomavr', 'bash', 'basic4gl', 'batch', 'bf', 'bibtex', 'blitzbasic', 'bnf', 'boo', 'brainfuck', 'bro', 'c', 'c_loadrunner', 'c_mac', 'caddcl', 'cadlisp', 'cfdg', 'chaiscript', 'chapel', 'cil', 'clojure', 'cmake', 'cobol', 'coffeescript', 'coldfusion', 'csharp', 'csp', 'css', 'cuesheet', 'd', 'dart', 'dcl', 'dcpu16', 'dcs', 'delphi', 'diff', 'div', 'dos', 'dot', 'e', 'ecmascript', 'eiffel', 'email', 'epc', 'erlang', 'euphoria', 'f#', 'falcon', 'filemaker', 'fo', 'f1', 'fortran', 'freebasic', 'freeswitch', 'gambas', 'gdb', 'genero', 'genie', 'gettext', 'glsl', 'gml', 'gnuplot', 'go', 'groovy', 'gwbasic', 'haskell', 'haxe', 'hicest', 'hq9plus', 'html4strict', 'html5', 'icon', 'idl', 'ini', 'inno', '

 40%|████████████████████████▏                                   | 910/2261 [00:02<00:04, 292.94it/s]

Parameter type error: "limit", expected integer, but got None. You need to change the input and try again.
Parameter type error: "offset", expected integer, but got None. You need to change the input and try again.
Parameter type error: "maxProtein", expected integer, but got None. You need to change the input and try again.
Parameter type error: "minFat", expected integer, but got None. You need to change the input and try again.
Parameter type error: "maxFat", expected integer, but got None. You need to change the input and try again.
Parameter type error: "minCarbs", expected integer, but got None. You need to change the input and try again.
Parameter type error: "maxProtein", expected integer, but got None. You need to change the input and try again.
Parameter type error: "maxFat", expected integer, but got None. You need to change the input and try again.
Parameter type error: "width", expected integer, but got None. You need to change the input and try again.
Parameter type error

 44%|██████████████████████████                                 | 1000/2261 [00:03<00:04, 258.58it/s]

Parameter type error: "skillLevel", expected one of ['entry-level', 'mid-level', 'senior-level'], but got "expert-level". You need to change the input and try again.
Parameter type error: "skillLevel", expected one of ['entry-level', 'mid-level', 'senior-level'], but got "all". You need to change the input and try again.
Parameter type error: "skillLevel", expected one of ['entry-level', 'mid-level', 'senior-level'], but got "intermediate-level". You need to change the input and try again.
Parameter type error: "skillLevel", expected one of ['entry-level', 'mid-level', 'senior-level'], but got "expert-level". You need to change the input and try again.


 48%|████████████████████████████▎                              | 1085/2261 [00:03<00:04, 250.69it/s]

Parameter type error: "limit", expected integer, but got string. You need to change the input and try again.
Parameter type error: "offset", expected integer, but got string. You need to change the input and try again.
Parameter type error: "routeId", expected integer, but got string. You need to change the input and try again.


 53%|███████████████████████████████                            | 1192/2261 [00:04<00:04, 251.64it/s]

Parameter type error: "rarity", expected integer, but got string. You need to change the input and try again.
Parameter type error: "monsterId", expected integer, but got string. You need to change the input and try again.


 54%|███████████████████████████████▊                           | 1219/2261 [00:04<00:04, 254.67it/s]

Parameter type error: "preferences", expected object, but got string. You need to change the input and try again.
Missing required parameters: "targetCurrency". You need to change the input and try again.
Parameter type error: "interval", expected one of ['1min', '5min', '15min', '30min', '60min', 'daily', 'weekly', 'monthly'], but got "hourly". You need to change the input and try again.
Missing required parameters: "playerName". You need to change the input and try again.


 65%|██████████████████████████████████████▍                    | 1472/2261 [00:05<00:02, 273.55it/s]

Parameter type error: "policyType", expected one of ['Lockdown', 'Travel restrictions', 'Social distancing', 'Vaccination', 'Testing', 'Contact tracing', 'Quarantine'], but got "Support for businesses". You need to change the input and try again.
Parameter type error: "policyType", expected one of ['Lockdown', 'Travel restrictions', 'Social distancing', 'Vaccination', 'Testing', 'Contact tracing', 'Quarantine'], but got "Mask mandates". You need to change the input and try again.
Parameter type error: "latitude", expected number, but got string. You need to change the input and try again.
Parameter type error: "longitude", expected number, but got string. You need to change the input and try again.
Parameter type error: "filters", expected object, but got string. You need to change the input and try again.
Parameter type error: "filters", expected object, but got string. You need to change the input and try again.
Parameter type error: "episodeId", expected integer, but got string. You

 69%|████████████████████████████████████████▋                  | 1557/2261 [00:05<00:02, 240.29it/s]

Parameter type error: "rarity", expected one of ['Common', 'Uncommon', 'Rare', 'Rare Holo', 'Rare Holo EX', 'Rare Ultra', 'Rare Secret', 'Rare Rainbow', 'Rare Prism', 'Rare ACE', 'Rare BREAK', 'Rare Holo GX', 'Rare Holo V', 'Rare Holo VMAX'], but got "Holo Rare". You need to change the input and try again.
Parameter type error: "type", expected one of ['Grass', 'Fire', 'Water', 'Lightning', 'Psychic', 'Fighting', 'Darkness', 'Metal', 'Fairy', 'Dragon', 'Colorless'], but got "Electric". You need to change the input and try again.
Parameter type error: "rarity", expected one of ['Common', 'Uncommon', 'Rare', 'Rare Holo', 'Rare Holo EX', 'Rare Ultra', 'Rare Secret', 'Rare Rainbow', 'Rare Prism', 'Rare ACE', 'Rare BREAK', 'Rare Holo GX', 'Rare Holo V', 'Rare Holo VMAX'], but got "Secret Rare". You need to change the input and try again.
Parameter type error: "productName", expected string, but got array. You need to change the input and try again.
Parameter type error: "upcCode", expected 

 74%|███████████████████████████████████████████▍               | 1664/2261 [00:05<00:01, 313.27it/s]

Parameter type error: "region", expected one of ['england-and-wales', 'scotland', 'northern-ireland'], but got "Northern Ireland". You need to change the input and try again.
Parameter type error: "region", expected one of ['england-and-wales', 'scotland', 'northern-ireland'], but got "Northern Ireland". You need to change the input and try again.


 83%|█████████████████████████████████████████████████          | 1879/2261 [00:06<00:01, 338.48it/s]

Parameter type error: "type", expected one of ['acts', 'regulations', 'bills'], but got "education". You need to change the input and try again.
Parameter type error: "id", expected integer, but got string. You need to change the input and try again.
Parameter type error: "feed_id", expected integer, but got string. You need to change the input and try again.


 86%|██████████████████████████████████████████████████▉        | 1950/2261 [00:06<00:00, 345.16it/s]

Parameter type error: "mmsi", expected integer, but got string. You need to change the input and try again.
Parameter type error: "lat", expected number, but got string. You need to change the input and try again.
Parameter type error: "lng", expected number, but got string. You need to change the input and try again.


 89%|████████████████████████████████████████████████████▋      | 2021/2261 [00:06<00:00, 333.34it/s]

Parameter type error: "lat", expected number, but got string. You need to change the input and try again.
Parameter type error: "lng", expected number, but got string. You need to change the input and try again.
Parameter type error: "lat", expected number, but got string. You need to change the input and try again.
Parameter type error: "lng", expected number, but got string. You need to change the input and try again.


 92%|██████████████████████████████████████████████████████▌    | 2089/2261 [00:07<00:00, 314.76it/s]

Parameter type error: "f_has_lyrics", expected string, but got integer. You need to change the input and try again.
Parameter type error: "f_has_lyrics", expected one of ['0', '1'], but got "1". You need to change the input and try again.
Parameter type error: "f_has_lyrics", expected string, but got integer. You need to change the input and try again.
Parameter type error: "f_has_lyrics", expected one of ['0', '1'], but got "1". You need to change the input and try again.
Parameter type error: "f_has_lyrics", expected string, but got integer. You need to change the input and try again.
Parameter type error: "f_has_lyrics", expected one of ['0', '1'], but got "1". You need to change the input and try again.
Parameter type error: "f_artist_id", expected integer, but got string. You need to change the input and try again.
Parameter type error: "f_music_genre_id", expected integer, but got string. You need to change the input and try again.
Parameter type error: "page_size", expected inte

 98%|█████████████████████████████████████████████████████████▌ | 2205/2261 [00:07<00:00, 345.69it/s]

Parameter type error: "type", expected one of ['rock', 'pop', 'hiphop', 'jazz', 'blues', 'country', 'classical', 'electronic', 'reggae', 'latin', 'metal', 'folk', 'indie', 'punk', 'rnb', 'soul', 'world', 'religious', 'kids', 'instrumental', 'other'], but got "upbeat". You need to change the input and try again.
Parameter type error: "type", expected one of ['rock', 'pop', 'hiphop', 'jazz', 'blues', 'country', 'classical', 'electronic', 'reggae', 'latin', 'metal', 'folk', 'indie', 'punk', 'rnb', 'soul', 'world', 'religious', 'kids', 'instrumental', 'other'], but got "upbeat". You need to change the input and try again.
Parameter type error: "type", expected one of ['rock', 'pop', 'hiphop', 'jazz', 'blues', 'country', 'classical', 'electronic', 'reggae', 'latin', 'metal', 'folk', 'indie', 'punk', 'rnb', 'soul', 'world', 'religious', 'kids', 'instrumental', 'other'], but got "romantic". You need to change the input and try again.
Parameter type error: "type", expected one of ['rock', 'pop

100%|███████████████████████████████████████████████████████████| 2261/2261 [00:07<00:00, 298.44it/s]

Missing required parameters: "courseCode". You need to change the input and try again.





In [34]:
len(cleaned_train)

1616

In [35]:
# json.dump(cleaned_train, open('../data/train_data_val_part_no_invalid_params_errors.json', 'w'))

Попробуем добавить дополнительный thought между ошибкой и ее исправлением, как в GPT4Tools

In [85]:
def parse_and_call_tool(prediction, tools):
    try:
        parsed = parser.parse(prediction)
        if isinstance(parsed, AgentFinish):
            return parsed.return_values['output']
        if parsed.tool in tools:
            tool = tools[parsed.tool]
            res = tool._run(parsed.tool_input)
        else:
            tool = CustomInvalidTool()
            res = tool._run(parsed.tool_input, tools)
        return res
    except Exception as e:
        return str(e)

def check_error(response):
    if 'Parameter type error' in response or 'Missing required parameters' in response or 'Could not parse LLM output' in response:
        return True
    return False

def check_error_w_type(response):
    if 'Parameter type error' in response:
        return ('Parameter type error', True)
    elif 'Missing required parameters' in response:
        return ('Missing required parameters', True)
    elif 'Could not parse LLM output' in response:
        return ('Could not parse LLM output', True)
    return (None, False)

def parse_query(meta_prompt):
    return meta_prompt.split('Begin!\n\nUSER: ')[1].split('\nASSISTANT Thought:')[0]

max_len = len(generations)
errors = []
for idx, (item_gt, item_pred) in enumerate(tqdm(zip(data[:max_len], generations))):
    query = parse_query(item_gt[0][0])
    tools = query_tool_mapping[query]

    flag = True
    for step_idx in range(len(item_gt[0])):
        if item_gt[1][step_idx]:
            res = parse_and_call_tool(item_gt[0][step_idx], tools)
            if check_error(res):
                flag = False

    if flag:
        for step_idx in range(len(item_gt[0])):
            if item_gt[1][step_idx]:
                res_pred = parse_and_call_tool(item_pred[0][step_idx], tools)
                res_gt = parse_and_call_tool(item_gt[0][step_idx], tools)
                error_status_pred = check_error_w_type(res_pred)
                error_status_gt = check_error_w_type(res_gt)
                if not error_status_gt[1] and error_status_pred[1]:
                    errors.append((item_gt, item_pred, error_status_pred[0], (step_idx, res_gt, res_pred)))

1676it [00:15, 110.83it/s]


Подсчитаем число ошибок каждого типа

In [86]:
error_types = []
for item in errors:
    error_types.append(item[-2])
Counter(error_types)

Counter({'Could not parse LLM output': 101,
         'Parameter type error': 76,
         'Missing required parameters': 20})

In [20]:
errors[1]

([['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions with the help of some tools.\nYou have access to the following tools:\n\ngetDetails: If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {"Question": "The question to prompt user to provide sufficient information."}\nOutput: User\'s response.\ngetCompetitions: Retrieve a list of all available competitions and their details.\nParameters: {}\nOutput: An array of objects, each representing a competition.\n - Format: application/json\n - Structur

In [88]:
correction_prompts = json.load(open('data/correction_thoughts.json', 'r'))

def construct_correction_example_w_thought(error):
    error_idx = error[-1][0]
    gt_chain = error[0][0]
    pred_chain = error[1][0]
    gt_mask_flags = error[0][1]
    error_pred = error[-1][-1]
    error_type = error[-2]
    correction_prompt = random.choice(correction_prompts[error_type])
    print(correction_prompt)
    correct_answer = correction_prompt + ' ' + gt_chain[error_idx]
    
    new_chain = gt_chain[:error_idx] + [pred_chain[error_idx], error_pred, correct_answer] + gt_chain[error_idx+1:]
    new_mask_flags = gt_mask_flags[:error_idx] + [False, False] + gt_mask_flags[error_idx:]
    return [new_chain, new_mask_flags]

In [89]:
correction_examples = []
for item in errors:
    correction_examples.append(construct_correction_example(item))

In [90]:
# json.dump(correction_examples, open('../data/correction_examples_w_internal_thoughts_val_model_trained_on_half_train.json', 'w'))

In [91]:
construct_correction_example_w_thought(errors[0])

The LLM output parsing failed, likely due to errors in the generation. Consider regenerating or altering the input.


[['A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions with the help of some tools.\nYou have access to the following tools:\n\ngetDetails: If the user\'s question lacks the essential information needed to answer the question effectively, or if the question contains vague terms or pronouns without sufficient context, invoke the `getDetails` function to prompt the user for the missing critical details. However, `getDetails` should not be used in cases where the user omits optional parameters, unless these parameters become necessary in the course of the conversation. \nParameters: {"Question": "The question to prompt user to provide sufficient information."}\nOutput: User\'s response.\nsearch: Search for internet assets based on various parameters\nParameters: {"query": "Required. string. The search query. Can be a domain name, IP address, SSL certificate information, or other paramete

In [None]:
def call_tool(generation, name_to_tool_map):
    agent_action = parser.parse(generation)
    if isinstance(agent_action, AgentFinish):
        return agent_action
    if agent_action.tool in name_to_tool_map:
        tool = name_to_tool_map[agent_action.tool]
        return_direct = tool.return_direct
        observation = tool.run(
            agent_action.tool_input,
            verbose=self.verbose,
            color=color,
            **tool_run_kwargs,
        )




        result = []
        for agent_action in actions:
            self.callback_manager.on_agent_action(
                agent_action, verbose=self.verbose, color="green"
            )
            # Otherwise we lookup the tool
            if agent_action.tool in name_to_tool_map:
                tool = name_to_tool_map[agent_action.tool]
                return_direct = tool.return_direct
                color = color_mapping[agent_action.tool]
                tool_run_kwargs = self.agent.tool_run_logging_kwargs()
                # =============================== modify ===============================
                # give GetDetailsTool more kwargs
                tool_run_kwargs["inputs"] = inputs
                # =============================== modify ===============================
                if return_direct:
                    tool_run_kwargs["llm_prefix"] = ""
                # We then call the tool on the tool input to get an observation
                observation = tool.run(
                    agent_action.tool_input,
                    verbose=self.verbose,
                    color=color,
                    **tool_run_kwargs,
                )
                # for testing
                # observation = '{"status code": 200}'
            else:
                tool_run_kwargs = self.agent.tool_run_logging_kwargs()
                observation = CustomInvalidTool().run(
                    agent_action.tool,
                    all_tools = list(name_to_tool_map.keys()),
                    verbose=self.verbose,
                    color=None,
                    **tool_run_kwargs,
                )
            result.append((agent_action, observation))
            return result

In [21]:
message = "'Action Input' cannot be a list. Only call one function per action."
message = "Status Code: 400. Response: {\"message\": \"" + message + "\""
message += ". You should choose one of: (1) change the input and retry; (2) return the 'Final Answer' and explain what happened; (You must choose this one when the error occurs more than 3 times.) (3) call another function."

In [22]:
message

'Status Code: 400. Response: {"message": "\'Action Input\' cannot be a list. Only call one function per action.". You should choose one of: (1) change the input and retry; (2) return the \'Final Answer\' and explain what happened; (You must choose this one when the error occurs more than 3 times.) (3) call another function.'