<center>
  <p style="text-align:center">
    Automating the tool/function-calling workflow
    <br>
    <a href="https://github.com/synacktraa/tool-parse">GitHub</a>
    |
    <a href="https://pypi.org/project/tool-parse">PyPI</a>
    |
    <a href="https://colab.research.google.com/drive/1C2WCgIZ7LnkpLt3KARL9ROh4iLwaACa6?usp=sharing">Colab</a>
  </p>
</center>

In [None]:
#@title Install and import necessary libraries
%pip install -qU tool-parse[pydantic] duckduckgo_search
%pip install -qU sentencepiece accelerate bitsandbytes

import json
import torch
from tool_parse import ToolRegistry
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m41.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h

### Creating a registry and registering tools

In [None]:
from tool_parse import ToolRegistry

from typing import Literal
from pydantic import BaseModel
from duckduckgo_search import DDGS

# DuckDuckGo client used by the search and translation tools defined below
ddgs = DDGS()
# Registry that collects every function/model decorated with @tr.register
tr = ToolRegistry()

@tr.register
def search_text(
    text: str,
    *,  # bare `*`: every parameter after it must be passed by keyword
    safe_search: bool = True,
    backend: Literal['api', 'html', 'lite'] = 'api',
    max_results: int = 1
):
    """
    Search for text in the web.
    :param text: Text to search for.
    :param safe_search: If True, enable safe search.
    :param backend: Backend to use for retrieving results.
    :param max_results: Max results to return.
    """
    # NOTE: tool-parse builds this tool's schema from the signature and the docstring
    # above (compare the schema output for 'search_text'), so edit them with care.

    # Translate the boolean flag into the string value DDGS.text receives.
    safesearch_mode = "off"
    if safe_search:
        safesearch_mode = "on"

    search_kwargs = dict(
        keywords=text,
        safesearch=safesearch_mode,
        backend=backend,
        max_results=max_results,
    )
    return ddgs.text(**search_kwargs)

@tr.register(name='product_information')  # registered under this custom name, not the class name
class ProductInfo(BaseModel): # This can be TypedDict or NamedTuple if you don't want to install pydantic package.
    """
    Information about the product.
    :param name: Name of the product.
    :param price: Price of the product.
    :param in_stock: If the product is in stock.
    """
    # NOTE: the docstring above plus these field names/annotations are what tool-parse
    # turns into the 'product_information' schema (all three fields end up required),
    # and compiling a call to that tool yields a ProductInfo instance — so keep the
    # field order, names and docstring in sync with what the model should see.
    name: str
    price: float  # exposed as JSON 'number'; integer inputs such as 3500 become 3500.0
    in_stock: bool

@tr.register
def get_translation(
    text: str,
    to: Literal['en', 'ja', 'hi', 'es', 'fr', 'de', 'zh'] = 'en'
):
    """
    Translate the given text.
    :param text: Text to translate.
    :param to: what language to translate.
    """
    # NOTE: the Literal annotation becomes an enum in the generated schema and the
    # docstring supplies the descriptions, so both are part of what the model sees.
    translation = ddgs.translate(keywords=text, to=to)
    return translation

In [None]:
"""
Lets take a look at each tool schema, we can get a particular tool schema from the registry like we're accessing a dictionary object with a key
"""
tr['search_text']

{'type': 'function',
 'function': {'name': 'search_text',
  'description': 'Search for text in the web.',
  'parameters': {'type': 'object',
   'properties': {'text': {'type': 'string',
     'description': 'Text to search for.'},
    'safe_search': {'type': 'boolean',
     'description': 'If True, enable safe search.'},
    'backend': {'enum': ['api', 'html', 'lite'],
     'type': 'string',
     'description': 'Backend to use for retrieving results.'},
    'max_results': {'type': 'integer',
     'description': 'Max results to return.'}},
   'required': ['text']}}}

In [None]:
# The class was registered under the custom name 'product_information', so its
# original name 'ProductInfo' is not a key in the registry (this prints False).
name_is_registered = 'ProductInfo' in tr
print(name_is_registered)
tr['product_information']

False


{'type': 'function',
 'function': {'name': 'product_information',
  'description': 'Information about the product.',
  'parameters': {'type': 'object',
   'properties': {'name': {'type': 'string',
     'description': 'Name of the product.'},
    'price': {'type': 'number', 'description': 'Price of the product.'},
    'in_stock': {'type': 'boolean',
     'description': 'If the product is in stock.'}},
   'required': ['name', 'price', 'in_stock']}}}

In [None]:
# The Literal annotation on `to` shows up as an enum in this schema
translation_tool_name = 'get_translation'
tr[translation_tool_name]

{'type': 'function',
 'function': {'name': 'get_translation',
  'description': 'Translate the given text.',
  'parameters': {'type': 'object',
   'properties': {'text': {'type': 'string',
     'description': 'Text to translate.'},
    'to': {'enum': ['en', 'ja', 'hi', 'es', 'fr', 'de', 'zh'],
     'type': 'string',
     'description': 'what language to translate.'}},
   'required': ['text']}}}

### Setting up the Gorilla pipeline

In [None]:
# Half precision when a CUDA GPU is available, full precision otherwise
torch_dtype = torch.float32
if torch.cuda.is_available():
    torch_dtype = torch.float16

# Quantization setup: 4-bit NF4 weights with double quantization, which compresses
# the model enough to run on Colab's GPU; the compute dtype follows `torch_dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_dtype,
)

# Model and tokenizer setup (weights are downloaded from the Hugging Face Hub on first run)
model_id: str = "gorilla-llm/gorilla-openfunctions-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_load_kwargs = dict(
    torch_dtype=torch_dtype,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

# Text-generation pipeline; each call generates at most 128 new tokens
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
)

# This function will return Gorilla's response
def get_gorilla_expression(query: str, functions: list[dict]) -> str:
    output = pipe(
        f"USER: <<question>> {query} <<function>> {json.dumps(functions)}\nASSISTANT: "
    )[0]['generated_text'].splitlines()[-1]
    return output[output.index(":")+1:].strip().strip('<<function>>')

tokenizer_config.json:   0%|          | 0.00/4.24k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/669 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/22.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.97G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.85G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

In [None]:
# Getting schema of registered tools

# Serialize all registered tools into a list of schemas (displayed below); this list
# is passed to get_gorilla_expression as `functions` in the following cells.
# NOTE(review): 'base' appears to select the plain {'type': 'function', ...} format
# shown in the output — confirm the available formats in the tool-parse docs.
tools = tr.marshal('base')
tools

[{'type': 'function',
  'function': {'name': 'search_text',
   'description': 'Search for text in the web.',
   'parameters': {'type': 'object',
    'properties': {'text': {'type': 'string',
      'description': 'Text to search for.'},
     'safe_search': {'type': 'boolean',
      'description': 'If True, enable safe search.'},
     'backend': {'enum': ['api', 'html', 'lite'],
      'type': 'string',
      'description': 'Backend to use for retrieving results.'},
     'max_results': {'type': 'integer',
      'description': 'Max results to return.'}},
    'required': ['text']}}},
 {'type': 'function',
  'function': {'name': 'product_information',
   'description': 'Information about the product.',
   'parameters': {'type': 'object',
    'properties': {'name': {'type': 'string',
      'description': 'Name of the product.'},
     'price': {'type': 'number', 'description': 'Price of the product.'},
     'in_stock': {'type': 'boolean',
      'description': 'If the product is in stock.'}},
 

In [None]:
# Search text
search_query = "Search for gorilla LLM benchmarks"
expression = get_gorilla_expression(query=search_query, functions=tools)
print(f"{expression=}")

# Let the registry parse the generated call and run the matching tool
output = tr.compile(expression)
print(f"{output=}")

expression="search_text(text='gorilla')"
output=[{'title': 'Gorilla - Wikipedia', 'href': 'https://en.wikipedia.org/wiki/Gorilla', 'body': 'The word gorilla comes from the history of Hanno the Navigator (c. 500 BC), a Carthaginian explorer on an expedition to the west African coast to the area that later became Sierra Leone. [1] [2] Members of the expedition encountered "savage people, the greater part of whom were women, whose bodies were hairy, and whom our interpreters called Gorillae".[3] [4] It is unknown whether what the ...'}]


In [None]:
# Extracting product information
extraction_query = "Parse: Product RTX 4900, priced at $3.5k, is in stock."
expression = get_gorilla_expression(query=extraction_query, functions=tools)
print(f"{expression=}")

# Compiling a call to the registered pydantic model instantiates it, so
# function/tool-calling doubles as structured-output extraction
output = tr.compile(expression)
print(f"{output=}")

expression="product_information(name='RTX 4900', price=3500, in_stock=True)"
output=ProductInfo(name='RTX 4900', price=3500.0, in_stock=True)


In [None]:
# Translation
# Without the explicit "Please utilize tool." the model kept translating the text
# itself instead of calling the translation tool.
translation_query = "'Tool calling is one of the best features of LLM' to hindi. Please utilize tool."
expression = get_gorilla_expression(query=translation_query, functions=tools)
print(f"{expression=}")

# Run the generated call through the registry
output = tr.compile(expression)
print(f"{output=}")

expression="get_translation(text='Tool calling is one of the best features of LLM', to='hi')"
output=[{'detected_language': 'en', 'translated': 'टूल कॉलिंग एलएलएम की सबसे अच्छी विशेषताओं में से एक है', 'original': 'Tool calling is one of the best features of LLM'}]
