In [2]:
from datasets import load_dataset

fc = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn", split="train")

# Split into train, validation, test 8:1:1
train = fc.train_test_split(test_size=0.2, seed=42)['train']
val_test = fc.train_test_split(test_size=0.5, seed=42)['test']
val = val_test.train_test_split(test_size=0.5, seed=42)['train']
test = val_test.train_test_split(test_size=0.5, seed=42)['test']

train.push_to_hub("riczhou/hermes-function-calling-v1-split", split="train")
val.push_to_hub("riczhou/hermes-function-calling-v1-split", split="validation")
test.push_to_hub("riczhou/hermes-function-calling-v1-split", split="test")

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 37.23ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.75it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 48.29ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 45.98ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/riczhou/hermes-function-calling-v1-split/commit/1540c59cfd30988079f374785ce44b776b00cff0', commit_message='Upload dataset', commit_description='', oid='1540c59cfd30988079f374785ce44b776b00cff0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/riczhou/hermes-function-calling-v1-split', endpoint='https://huggingface.co', repo_type='dataset', repo_id='riczhou/hermes-function-calling-v1-split'), pr_revision=None, pr_num=None)

In [8]:
# Show first test example
print(test[1])

{'task': 'Trigger Notifications with a POST Request', 'conversations': [{'from': 'system', 'value': "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.\n<tools>\n[{'type': 'function', 'function': {'name': 'configureJiraNotificationWorkflow', 'description': 'Sets up an automated workflow to send notifications via POST request when specified events occur in Jira.', 'parameters': {'type': 'object', 'properties': {'issue_tracking_system': {'type': 'string', 'description': 'The issue tracking system to monitor for events.'}, 'notification_endpoint': {'type': 'string', 'description': 'The endpoint URL to send the POST request for notifications.'}, 'event_types': {'type': 'array', 'description': 'List of event types to trigger notifications.', 'items': {'type': 'string'}}}, 'required': ['issue_tracking_

In [11]:
def split_sample(sample):
    assert len(sample["conversations"]) == 3

    system = None
    human = None
    gpt = None

    for turns in sample["conversations"]:
        if turns["from"] == "system":
            system = turns["value"]
        elif turns["from"] == "human":
            human = turns["value"]
        elif turns["from"] == "gpt":
            gpt = turns["value"]

    return {
        "system": system,
        "human": human,
        "gpt": gpt
    }

sample = split_sample(test[1])

In [42]:
import ast
import json
from typing import List, Optional, Dict, Union
import re


def clean_json_string(s: str) -> str:
    s = s.replace('\\\\n', '').replace('\\n', '')
    return ' '.join(s.split())


def parse_completion(gpt: str) -> Optional[List[Dict[str, Union[str, dict]]]]:
    try:
        if not isinstance(gpt, str):
            raise ValueError(f"Input must be a string, got {type(gpt)}")
        
        if not gpt.strip():
            return None
            
        pattern = r"<tool_call>(.*?)</tool_call>"
        matches = re.findall(pattern, gpt, re.DOTALL)
        
        if not matches:
            return None
            
        valid_tools = []
        
        for potential_json in matches:
            try:
                cleaned_json = clean_json_string(potential_json)
                
                if not cleaned_json:
                    continue
                
                # First try to parse as regular JSON
                try:
                    tool_call_json = json.loads(cleaned_json)
                except json.JSONDecodeError:
                    # If JSON parsing fails, try converting single quotes to double quotes
                    try:
                        # Use ast.literal_eval to safely evaluate the string as a Python literal
                        tool_call_json = ast.literal_eval(cleaned_json)
                    except (SyntaxError, ValueError) as e:
                        # If both methods fail, try replacing single quotes with double quotes
                        # but only for the outermost quotes
                        if cleaned_json.startswith("'") and cleaned_json.endswith("'"):
                            cleaned_json = cleaned_json[1:-1]  # Remove outer quotes
                            cleaned_json = f'"{cleaned_json}"'  # Add double quotes
                            try:
                                tool_call_json = json.loads(cleaned_json)
                            except:
                                continue
                        else:
                            continue
                
                if not isinstance(tool_call_json, dict):
                    continue

                restructured_tool = restructure_tool_call(tool_call_json)
                
                if restructured_tool['name'] is not None:
                    valid_tools.append(restructured_tool)
                
            except Exception:
                continue
        
        return valid_tools if valid_tools else None
        
    except Exception as e:
        return None
    

def restructure_tool_call(data: Dict) -> Dict:
    result = {'name': None, 'arguments': {}}
    
    def extract_fields(d):
        for k, v in d.items():
            if k == 'name' and result['name'] is None:
                result['name'] = v
            elif k == 'arguments' and not result['arguments']:
                result['arguments'] = v
            elif isinstance(v, dict):
                extract_fields(v)
    
    extract_fields(data)
    
    if result['name'] is None and 'name' in result['arguments']:
        result['name'] = result['arguments'].pop('name')
    
    if not result['arguments']:
        result['arguments'] = {k: v for k, v in data.items() 
                             if k != 'name' and k != 'arguments'}
    
    return result


def validate_tool_calls(tool_calls: List[Dict]) -> List[Dict]:
    valid_calls = []
    
    for tool_call in tool_calls:
        if not isinstance(tool_call, dict):
            continue
            
        required_fields = {'name', 'arguments'}
        if not all(field in tool_call for field in required_fields):
            continue
            
        if not isinstance(tool_call['name'], str):
            continue
            
        if not isinstance(tool_call['arguments'], dict):
            continue
            
        valid_calls.append(tool_call)
        
    return valid_calls


completions = parse_completion(sample["gpt"])
print(completions)
completions = validate_tool_calls(completions)
completions

[{'name': 'gather_social_media_data', 'arguments': {'platforms': ['Facebook', 'Twitter', 'Instagram'], 'metrics': ['likes', 'comments', 'shares', 'new followers'], 'time_period': 'weekly'}}, {'name': 'optimize_google_ads_bids', 'arguments': {'campaign_id': '12345', 'optimization_strategy': 'return_on_ad_spend', 'target_roas': 4.0}}, {'name': 'generate_seo_reports', 'arguments': {'domains': ['trendytech.com'], 'report_type': 'link_building', 'frequency': 'monthly'}}]


[{'name': 'gather_social_media_data',
  'arguments': {'platforms': ['Facebook', 'Twitter', 'Instagram'],
   'metrics': ['likes', 'comments', 'shares', 'new followers'],
   'time_period': 'weekly'}},
 {'name': 'optimize_google_ads_bids',
  'arguments': {'campaign_id': '12345',
   'optimization_strategy': 'return_on_ad_spend',
   'target_roas': 4.0}},
 {'name': 'generate_seo_reports',
  'arguments': {'domains': ['trendytech.com'],
   'report_type': 'link_building',
   'frequency': 'monthly'}}]

In [43]:
# verify how many samples in the test set have valid completions

valid_completions = 0

for sample in test:
    sample = split_sample(sample)
    completions = parse_completion(sample["gpt"])
    if completions is None:
        continue
    completions = validate_tool_calls(completions)
    if completions:
        valid_completions += 1

print(f"{valid_completions} out of {len(test)} samples have valid completions")

472 out of 474 samples have valid completions
