In [None]:
import json
import yaml
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from copy import deepcopy
import random
random.seed(0)

In [None]:
# Load the cleaned training conversations. The context manager closes the file
# handle as soon as parsing is done (the bare open() call left it to the
# garbage collector), and the encoding is pinned because JSON interchange
# files are UTF-8 regardless of the platform's default locale.
with open('data/toolbench_new_1311/cleaned_toolllama_G123_dfs_train_downloaded1311_no_undetectable_errors_final.json', 'r', encoding='utf-8') as f:
    train_data_original = json.load(f)

In [None]:
# Display the first loaded record to inspect its structure.
train_data_original[0]

In [None]:
def get_tool_set(item):
    """Parse the API list embedded in the first message of a conversation.

    The first entry of ``item['conversations']`` is expected to contain the
    sentence "Specifically, you have access to the following APIs: " followed
    by a YAML-parseable list of API definitions. The text right after the
    first occurrence of that sentence (up to the next occurrence, if any) is
    handed to PyYAML.

    Args:
        item: Mapping with a ``'conversations'`` list whose first element has
            a string ``'value'``.

    Returns:
        The parsed list of API definitions, or ``[]`` when the marker sentence
        is missing, the text cannot be parsed, or the parsed value is not a
        list (so callers can always iterate over the result).
    """
    api_system_msg = "\nSpecifically, you have access to the following APIs: "

    system_step = item['conversations'][0]['value']

    sections = system_step.split(api_system_msg)
    if len(sections) < 2:
        # No API section in this prompt: report "nothing parseable" the same
        # way a YAML failure does, rather than raising IndexError.
        return []

    try:
        # SECURITY NOTE(review): yaml.Loader is PyYAML's full loader and can
        # construct arbitrary Python objects from tagged input. That is only
        # acceptable while the dataset is trusted; consider yaml.safe_load if
        # the prompts can come from an untrusted source.
        tools = yaml.load(sections[1], yaml.Loader)
    except Exception:
        # Best effort: unparseable prompts are skipped by the caller.
        return []

    return tools if isinstance(tools, list) else []

In [None]:
# Keep only the records whose API list parses to something non-empty. Each
# kept record is an independent deep copy with the parsed list stored under
# 'tools', so train_data_original itself is never modified.
train_data_parseble_tools = []
for item in tqdm(train_data_original):
    tools = get_tool_set(item)
    if not tools:
        continue
    train_data_parseble_tools.append({**deepcopy(item), 'tools': tools})

In [None]:
# Number of records that kept a non-empty parsed API list.
len(train_data_parseble_tools)

In [None]:
# Sanity check: display what get_tool_set extracts from the first record.
get_tool_set(train_data_original[0])

In [None]:
# Collect every API name seen in the data (all_apis) together with the tool
# name derived from it (all_tools).
all_apis = set()
all_tools = set()
for item in tqdm(train_data_parseble_tools):
    for tool in item['tools']:
        tool_name = tool['name']
        all_apis.add(tool_name)
        # Test for the exact separator that is split on: a name that merely
        # contains the letters "for" (e.g. "get_forecast") has no '_for_'
        # piece, and indexing [1] on it raised IndexError.
        # NOTE(review): with several '_for_' separators this keeps the segment
        # after the first one only - confirm that is the intended tool name.
        if '_for_' in tool_name:
            all_tools.add(tool_name.split('_for_')[1])
        else:
            all_tools.add(tool_name)

In [None]:
def get_tool_names(item):
    """Return the set of tool names referenced by an item's API list.

    Each entry of ``item['tools']`` carries an API name under ``'name'``.
    A name containing the separator ``'_for_'`` contributes the segment that
    immediately follows its first occurrence (up to the next ``'_for_'``, if
    any); every other name (for example ``'Finish'``) is used unchanged.

    Args:
        item: Mapping with a ``'tools'`` iterable of mappings that each have a
            string ``'name'``.

    Returns:
        A new ``set`` of tool-name strings; empty when there are no tools.
    """
    tool_names = set()
    for tool in item['tools']:
        api_name = tool['name']
        # Check for the exact separator used by the split. Testing only for
        # the substring "for" let names such as "get_forecast" through, and
        # the [1] lookup then raised IndexError.
        if '_for_' in api_name:
            tool_names.add(api_name.split('_for_')[1])
        else:
            tool_names.add(api_name)
    return tool_names

In [None]:
# Split the tool names into a training pool and a validation pool, then walk
# the conversations once and move tools between pools whenever a conversation
# uses tools from both.
SPLIT_SEED = 0
VAL_FRACTION = 0.4

# sorted(): iteration order of a set of strings changes between interpreter
# runs (hash randomisation), so the input order is fixed before shuffling.
# random_state: scikit-learn shuffles with its own RNG, which random.seed()
# does not control, so the split was different on every run.
train_tools, val_tools = train_test_split(
    sorted(all_tools), test_size=VAL_FRACTION, random_state=SPLIT_SEED
)
train_tools = set(train_tools)
val_tools = set(val_tools)
# 'Finish' is always kept in the training pool.
if 'Finish' in val_tools:
    val_tools.remove('Finish')
    train_tools.add('Finish')
print(len(train_tools), len(val_tools))

# A dedicated generator makes the cell idempotent: re-running it replays the
# same sequence of draws instead of continuing the global random state.
rng = random.Random(SPLIT_SEED)
for item in tqdm(train_data_parseble_tools):
    tools = get_tool_names(item)
    # discard(): a conversation without a 'Finish' tool must not raise KeyError.
    tools.discard('Finish')
    tools_from_train = tools.intersection(train_tools)
    tools_from_val = tools.intersection(val_tools)
    if tools_from_train and tools_from_val:
        # Mixed conversation: move all of its tools to one side, favouring
        # the training pool with probability 1 - VAL_FRACTION.
        if rng.uniform(0, 1) >= VAL_FRACTION:
            for tool in tools_from_val:
                val_tools.remove(tool)
                train_tools.add(tool)
        else:
            for tool in tools_from_train:
                train_tools.remove(tool)
                val_tools.add(tool)
# NOTE: this is a single pass; a later move can make an already-visited
# conversation mixed again, so some overlap may remain afterwards.

print(len(train_tools), len(val_tools))

In [None]:
# Assign every conversation to a split according to the tool pools, counting
# how many conversations touch each pool and how many touch both.
intersection_count = 0
train_count = 0
val_count = 0

train_sample = []
val_sample = []
for item in tqdm(train_data_parseble_tools):
    tools = get_tool_names(item)
    # discard(): a conversation without a 'Finish' tool must not raise KeyError.
    tools.discard('Finish')
    tools_from_train = tools.intersection(train_tools)
    tools_from_val = tools.intersection(val_tools)
    if tools_from_train:
        train_count += 1
    if tools_from_val:
        val_count += 1
    if tools_from_train and tools_from_val:
        intersection_count += 1

    # Only conversations that use training tools exclusively go to train.
    # Everything else - validation-only, mixed, and conversations with no
    # tool in either pool - lands in the validation sample.
    if tools_from_train and not tools_from_val:
        train_sample.append(item)
    else:
        val_sample.append(item)

print(train_count, val_count, intersection_count)
print(len(train_sample))
print(len(val_sample))

In [None]:
# Size summary, one number per line: all records, records with a parsed API
# list, the training sample, the validation sample, and the two samples added
# together.
print(
    len(train_data_original),
    len(train_data_parseble_tools),
    len(train_sample),
    len(val_sample),
    len(train_sample) + len(val_sample),
    sep='\n',
)

In [None]:
def split_data_last_tool_call(data):
    """Expand each conversation into one sample per assistant turn.

    For every turn whose ``'from'`` field equals ``'assistant'``, a sample is
    produced that carries all fields of the source item, with
    ``'conversations'`` cut off right after that turn.

    Only the retained prefix is deep-copied. Copying the whole item and then
    slicing, as before, also duplicated every later turn just to throw it
    away, once per assistant turn.

    Args:
        data: Iterable of dict items, each with a ``'conversations'`` list of
            mappings that have a ``'from'`` key.

    Returns:
        A list of new dicts, in source order, that share no mutable state
        with ``data``. Key order of each item is preserved.
    """
    data_splitted = []
    for item in data:
        conversations = item['conversations']
        for i, turn in enumerate(conversations):
            if turn['from'] != 'assistant':
                continue
            prefix = conversations[:i + 1]
            data_splitted.append({
                key: deepcopy(prefix if key == 'conversations' else value)
                for key, value in item.items()
            })
    return data_splitted

In [None]:
# Build the per-assistant-turn samples for the full, unfiltered dataset and
# display how many there are.
# NOTE(review): train_original_splitted is not referenced by any later cell
# visible here - confirm it is still needed before paying for the copies.
train_original_splitted = split_data_last_tool_call(train_data_original)
len(train_original_splitted)

In [None]:
# Expand both splits into per-assistant-turn samples.
train_sample_splitted = split_data_last_tool_call(train_sample)
val_sample_splitted = split_data_last_tool_call(val_sample)

In [None]:
# Number of per-assistant-turn samples in the training split.
len(train_sample_splitted)

In [None]:
# Number of per-assistant-turn samples in the validation split.
len(val_sample_splitted)

In [None]:
# Write the per-assistant-turn training samples. The context manager
# guarantees the file is flushed and closed when the cell finishes; the bare
# open() call never closed the handle explicitly.
with open('data/toolbench_new_1311/cleaned_toolllama_G123_dfs_train_downloaded1311_no_undetectable_errors_final_train_train.json', 'w', encoding='utf-8') as f:
    json.dump(train_sample_splitted, f)

In [None]:
# Write the per-assistant-turn validation samples. The context manager
# guarantees the file is flushed and closed when the cell finishes; the bare
# open() call never closed the handle explicitly.
with open('data/toolbench_new_1311/cleaned_toolllama_G123_dfs_train_downloaded1311_no_undetectable_errors_final_train_val.json', 'w', encoding='utf-8') as f:
    json.dump(val_sample_splitted, f)

In [None]:
# Number of full conversations in the training split.
len(train_sample)

In [None]:
# Number of full conversations in the validation split.
len(val_sample)

In [None]:
# Write the full (un-expanded) training conversations. The context manager
# guarantees the file is flushed and closed when the cell finishes; the bare
# open() call never closed the handle explicitly.
with open('data/toolbench_new_1311/cleaned_toolllama_G123_dfs_train_downloaded1311_no_undetectable_errors_final_train_train_full_chains.json', 'w', encoding='utf-8') as f:
    json.dump(train_sample, f)

In [None]:
# Write the full (un-expanded) validation conversations. The context manager
# guarantees the file is flushed and closed when the cell finishes; the bare
# open() call never closed the handle explicitly.
with open('data/toolbench_new_1311/cleaned_toolllama_G123_dfs_train_downloaded1311_no_undetectable_errors_final_train_val_full_chains.json', 'w', encoding='utf-8') as f:
    json.dump(val_sample, f)