In [7]:
# Add the current working directory and its parent to sys.path so the
# project-local `utils` package can be imported.
import os,sys
sys.path.append(os.path.abspath(".")) 
sys.path.append(os.path.abspath("..")) 
import torch
import json
import numpy as np
from tqdm import tqdm
from collections import defaultdict
# NOTE(review): star imports. Names used later but never defined in this notebook
# (seed_all, encode_fn, async_process) presumably come from these modules — TODO
# confirm and switch to explicit imports. Order is kept as-is because a star import
# can shadow names imported before it.
from utils.utils import *
from utils.plot_utils import *
# Preprocessing only: disable autograd globally for this kernel.
torch.set_grad_enabled(False)
from transformers import AutoTokenizer

# Project seeding helper; the np.random.choice in the filtering cell relies on it for
# reproducible sample selection (TODO confirm seed_all seeds numpy's global RNG).
seed_all()


In [8]:
# --- Runtime configuration -------------------------------------------------
device, torch_dtype = 'cuda:0', torch.bfloat16
max_ctx_len = 32768
# Machine-specific model directory; in this notebook it is only referenced by the
# commented-out alternative below.
data_dir = '/dataset/common/huggingface/model'

# --- Model / tokenizer -----------------------------------------------------
model_path = 'Qwen/Qwen3-8B'
# Alternative (commented out): use the Qwen3-8B-Toolace_ASIDE checkpoint under data_dir,
# loaded through the project's load_model helper with vLLM.
# model_path = os.path.join(data_dir,'Qwen3-8B-Toolace_ASIDE')
# model,tokenizer,is_aside = load_model(model_path,use_vllm=True,dtype=torch_dtype,vllm_kwargs = {'gpu_memory_utilization':0.8,'enable_chunked_prefill':False},max_ctx_len=max_ctx_len)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)


In [9]:
# Load the ToolBench `toolllama_G123_dfs` train/eval JSON splits.
# NOTE(review): absolute, machine-specific paths — consider a configurable data dir.
toolbench_split_paths = {
    'train': "/export/home2/weijie210/ToolBench/ToolBench_data/data/toolllama_G123_dfs_train.json",
    'eval': "/export/home2/weijie210/ToolBench/ToolBench_data/data/toolllama_G123_dfs_eval.json",
}
ds = {}
for split_name, split_path in toolbench_split_paths.items():
    with open(split_path, "r") as f:
        ds[split_name] = json.load(f)

In [None]:
# Partition every split into conversations containing at least one tool response
# (a turn whose 'from' is 'function') and conversations that never call a tool.
tool_ds = {}
non_tool_ds = {}
for ds_key, ds_ in ds.items():
    with_tool, without_tool = [], []
    for convo in ds_:
        has_tool_turn = any([t['from'] == 'function' for t in convo['conversations']])
        (with_tool if has_tool_turn else without_tool).append(convo)
    tool_ds[ds_key] = with_tool
    non_tool_ds[ds_key] = without_tool
    print (f'{ds_key}: {len(tool_ds[ds_key])} samples with tool use, {len(non_tool_ds[ds_key])} samples without tool use')

train: 149982 samples with tool use, 37560 samples without tool use
eval: 609 samples with tool use, 153 samples without tool use


In [5]:
import ast,copy

def _parse_json_or_literal(text):
    """Parse ``text`` as JSON, falling back to a Python literal; ``None`` if neither works."""
    # `except Exception` (not bare `except:`) so Ctrl-C / SystemExit still propagate
    # out of the long preprocessing loop instead of being silently swallowed.
    try:
        return json.loads(text)
    except Exception:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return None


def parse_assistant_response(response):
    """Split a ToolBench ReAct-style assistant message into its parts.

    Expected shape::

        Thought: ...
        Action: <api name>
        Action Input: <JSON or Python-literal arguments>

    Args:
        response: raw assistant message text.

    Returns:
        ``(thought, action, action_input)``:

        * ``action_input`` is the text after the last "Action Input:" (newlines
          replaced by spaces) parsed with ``json.loads``, falling back to
          ``ast.literal_eval``; ``None`` if absent or unparseable.
        * When "Action Input:" is present, ``thought``/``action`` are filled only
          if the input parsed to a non-None value and "Action:" also occurs.
        * When only "Action:" is present, ``action`` is the text after it and
          ``action_input`` stays ``None``.
        * If "Thought:" is missing, ``thought`` is the text before "Action:".
    """
    thought = action = action_input = None
    if "Action Input:" in response:
        # Raw newlines are illegal inside JSON strings; flatten them before parsing.
        raw_input = response.split("Action Input:")[-1].strip().replace('\n', ' ')
        action_input = _parse_json_or_literal(raw_input)
        if action_input is not None and "Action:" in response:
            action = response.split("Action:")[-1].split("Action Input:")[0].strip()
            thought = response.split("Thought:")[-1].split("Action:")[0].strip()
    elif "Action:" in response:
        action = response.split("Action:")[-1].strip()
        thought = response.split("Thought:")[-1].split("Action:")[0].strip()
    return thought, action, action_input

def convert_turn_keys(messages):
    """Convert ShareGPT-style turns into chat-template turns.

    Returns a deep copy of ``messages`` in which each turn's ``from`` key becomes
    ``role`` and ``value`` becomes ``content``; a ``'function'`` role is renamed to
    ``'tool'``. Any other keys are kept. The input list is left unmodified.
    """
    converted = copy.deepcopy(messages)
    role_alias = {'function': 'tool'}
    for turn in converted:
        speaker = turn.pop('from')
        turn['role'] = role_alias.get(speaker, speaker)
        turn['content'] = turn.pop('value')
    return converted

def format_assistant_response(action,action_input):
    """Render a tool call as a Qwen-style ``<tool_call>`` block.

    The payload is ``json.dumps({"name": action, "arguments": action_input})``
    placed on its own line between ``<tool_call>`` and ``</tool_call>``.
    """
    payload = json.dumps({'name': action, 'arguments': action_input})
    return "<tool_call>\n{tool_call}\n</tool_call>".format(tool_call=payload)

def process_tool_response(resp):
    """Parse a ``<tool_call>...</tool_call>`` string back into a Python object.

    The tags are removed and the remainder is parsed with ``json.loads``, falling
    back to ``ast.literal_eval``.

    Args:
        resp: text that may contain ``<tool_call>`` / ``</tool_call>`` tags.

    Returns:
        The parsed object (typically ``{"name": ..., "arguments": ...}``), or
        ``None`` when neither parser accepts the text.
    """
    str_resp = resp.replace('<tool_call>', '').replace('</tool_call>', '').strip()
    # `except Exception` rather than bare `except:` so KeyboardInterrupt/SystemExit propagate.
    try:
        return json.loads(str_resp)
    except Exception:
        pass
    try:
        return ast.literal_eval(str_resp)
    except Exception:
        return None


In [6]:
# For every tool-using conversation:
#   * rewrite assistant turns from ToolBench's "Thought/Action/Action Input" text into
#     Qwen <tool_call> blocks ('Finish' turns become the plain final answer),
#   * collect every assistant turn that comes after at least one tool response,
#   * keep ONE randomly chosen prefix (up to and including that turn, system turn dropped).
filtered_tool_ds = defaultdict(list)
for ds_key, ds_ in tool_ds.items():
    for sample in tqdm(ds_,total = len(ds_)):
        sample = copy.deepcopy(sample)  # turns are rewritten in place below; keep tool_ds untouched
        conversations = sample['conversations']
        valid_assistant_turns = []
        # Reset per sample so a sample without a system turn can never reuse the previous
        # sample's tools (or hit a NameError on the very first sample).
        tools = None
        start_tracking = False
        for i,turn in enumerate(conversations):
            role = turn['from']
            if role == 'system':
                # Tool specs follow this marker in the system prompt as a Python literal.
                tools_str = turn['value'].split('you have access to the following APIs:')[-1].strip()
                try:
                    tools = ast.literal_eval(tools_str)
                except Exception:  # malformed spec: drop the sample instead of crashing the loop
                    tools = None
                # Must be `or`: with `and`, a list containing non-dict entries slipped
                # through and crashed on t['name'] below.
                if not isinstance(tools, (list, tuple)) or not all(isinstance(t, dict) for t in tools):
                    tools = None
                    break
                tools = [t for t in tools if t['name'].lower() != 'finish']
            elif role == 'function':
                start_tracking = True
            elif role == 'assistant':
                thought, action, action_input = parse_assistant_response(turn['value'])
                if thought is None or action is None or action_input is None: # cant parse assistant resp
                    break
                turn['value'] = format_assistant_response(action,action_input)
                if action.lower().strip() == 'finish': # not a tool, just the string
                    # Require a dict: `in` on a str is a substring test and on an int raises TypeError.
                    if not isinstance(action_input, dict) or 'final_answer' not in action_input:
                        break
                    turn['value'] = action_input['final_answer']
                if start_tracking:
                    # Assumes the system turn sits at index 0 and is dropped here.
                    valid_assistant_turns.append(conversations[1:i+1])
        if valid_assistant_turns and tools is not None:
            random_to_generate_id = np.random.choice(range(len(valid_assistant_turns)))
            random_to_generate = convert_turn_keys(valid_assistant_turns[random_to_generate_id])
            filtered_tool_ds[ds_key].append({'conversations':random_to_generate,'tools':tools})

    print (f'{ds_key}: {len(filtered_tool_ds[ds_key])} samples for on-policy generation')


            
            

100%|██████████| 149982/149982 [01:10<00:00, 2133.49it/s]


train: 148007 samples for on-policy generation


100%|██████████| 609/609 [00:00<00:00, 2638.64it/s]

eval: 600 samples for on-policy generation





In [12]:
def tool_prompt_format(messages,tools,tokenizer,encode=False,enable_thinking=False): # for tool-use agents like Qwen
    """Render a tool-use conversation with the tokenizer's chat template.

    Args:
        messages: chat turns (dicts with a ``'role'`` key) to render.
        tools: tool specs forwarded as ``apply_chat_template(tools=...)``.
        tokenizer: tokenizer exposing ``name_or_path`` and ``apply_chat_template``.
        encode: if True, return ``encode_fn(formatted, tokenizer)`` (project helper)
            instead of the rendered string.
        enable_thinking: forwarded to the template only when ``'qwen'`` appears in
            ``tokenizer.name_or_path``.

    Returns:
        The rendered (untokenized) string, or ``encode_fn``'s result when ``encode``.

    The assistant generation header is appended only when the conversation does NOT
    already end with an assistant turn, i.e. when the output is a prompt awaiting a
    reply. The previous logic was inverted: it appended a dangling header after a
    completed assistant reply and omitted it from real prompts.
    """
    # An empty history is treated as a prompt awaiting a reply.
    add_generation_prompt = not messages or messages[-1]['role'] != 'assistant'
    template_kwargs = {
        'add_generation_prompt': add_generation_prompt,
        'tokenize': False,
        'tools': tools,
    }
    if 'qwen' in tokenizer.name_or_path.lower():
        # Only Qwen templates are given `enable_thinking`; other templates may not expect it.
        template_kwargs['enable_thinking'] = enable_thinking
    formatted = tokenizer.apply_chat_template(messages, **template_kwargs)
    return encode_fn(formatted,tokenizer) if encode else formatted

In [20]:
def check_valid(sample):
    """Flag a structurally INVALID sample (despite the name, True means invalid).

    Collects the roles of ``sample['conversations']``, ignores the first two, and
    returns True if any remaining ``'tool'`` turn is either the final turn or is not
    immediately followed by an ``'assistant'`` or ``'user'`` turn; otherwise False.
    """
    roles = [turn['role'] for turn in sample['conversations']][2:]
    next_roles = roles[1:] + [None]  # None stands in for "no following turn"
    return any(
        role == 'tool' and following not in ('assistant', 'user')
        for role, following in zip(roles, next_roles)
    )

# Count structurally invalid samples per split (check_valid returns True for INVALID).
invalid_count = {k: sum(check_valid(s) for s in v) for k, v in filtered_tool_ds.items()}
print ("Invalid counts (tool ds): ", invalid_count)
# Enforce the invariant instead of eyeballing it, so Restart & Run All fails loudly.
assert not any(invalid_count.values()), f"found structurally invalid samples: {invalid_count}"

Invalid counts (tool ds):  {'train': 0, 'eval': 0}


In [None]:
# Filter samples by the token length of the fully rendered conversation (tools + turns).
# MAX_LENGTH was previously assigned twice (4096, then 4096 - 512) and only the second
# value was used; it is now defined once from named parts.
CONTEXT_LEN = 4096    # matches the `_4096` suffix of the exported files
GEN_RESERVE = 512     # NOTE(review): presumably headroom for on-policy generation — confirm
MAX_LENGTH = CONTEXT_LEN - GEN_RESERVE

def compute_length(sample):
    """Annotate ``sample`` in place with ``input_len`` and return it.

    ``input_len`` is the number of tokens (via the global ``tokenizer``) in the
    ``tool_prompt_format`` rendering of ``sample["conversations"]`` with
    ``sample["tools"]`` and thinking disabled.
    """
    text = tool_prompt_format(
        sample["conversations"],
        tools=sample["tools"],
        tokenizer=tokenizer,
        enable_thinking=False
    )
    sample["input_len"] = len(tokenizer.encode(text, add_special_tokens=True))
    return sample

filtered_tool_ds = {
    k: async_process(compute_length, v, workers=32, msg="Computing input lengths")
    for k, v in filtered_tool_ds.items()
}
filtered_tool_ds = {k: [d for d in v if d['input_len'] < MAX_LENGTH] for k, v in filtered_tool_ds.items()}

print (f'After filtering by length < {MAX_LENGTH}, we have {len(filtered_tool_ds["train"])} train samples and {len(filtered_tool_ds["eval"])} val samples')

Computing input lengths:  18%|█▊        | 26622/148007 [00:33<02:33, 792.00it/s] 


In [27]:
# Build the exported train/val lists from the length-filtered data, keeping only the
# fields the training files need (drops the temporary `input_len` annotation).
# The previous version read the undefined `filtered_train_ds`/`filtered_val_ds` and a
# nonexistent 'input' key, so it only ran on stale kernel state.
filtered_train_ds = [{'conversations': d['conversations'], 'tools': d['tools']} for d in filtered_tool_ds['train']]
filtered_val_ds = [{'conversations': d['conversations'], 'tools': d['tools']} for d in filtered_tool_ds['eval']]

In [None]:
# Write the train/val splits as indented JSON. open(..., 'w') does not create parent
# directories, so make sure the output directory exists first.
os.makedirs('aside/data/train_data', exist_ok=True)

with open('aside/data/train_data/toolbench_train_4096.json','w') as f:
    json.dump(filtered_train_ds,f,indent=4)

with open('aside/data/train_data/toolbench_val_4096.json','w') as f:
    json.dump(filtered_val_ds,f,indent=4)