# Function Calling with the new Gorilla model

Copied and adjusted from https://github.com/abacaj/openhermes-function-calling/tree/main
Model: gorilla-llm/gorilla-openfunctions-v1
Dataset: https://github.com/ShishirPatil/gorilla/tree/main/openfunctions

In [None]:
!pip install transformers sentence-transformers

In [38]:
import warnings

warnings.filterwarnings("ignore")

In [1]:
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import CrossEncoder

# Hugging Face identifiers for the re-ranker and the function-calling LLM.
# NOTE(review): the notebook header names gorilla-openfunctions-v1 while v0 is
# loaded here — confirm which checkpoint is intended.
ranker_id = "cross-encoder/ms-marco-MiniLM-L-12-v2"
model_id = "gorilla-llm/gorilla-openfunctions-v0"

# Load the tokenizer and the causal LM (bfloat16 weights, device_map="auto").
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Text-generation pipeline used by the generation helpers in this notebook.
pipe = pipeline("text-generation", tokenizer=tokenizer, model=model)

# Ranks function descriptions based on their similarity to the query.
def rank_pipeline(model_id):
    """Create a ranking callable backed by a cross-encoder.

    Args:
        model_id: Identifier passed to ``CrossEncoder`` to load the model.

    Returns:
        A function ``rank(query, functions)`` that returns the entry of
        ``functions`` whose ``"description"`` receives the highest score
        for ``query``.
    """
    cross_encoder = CrossEncoder(model_id)

    def rank(query, functions):
        """Return the highest-scoring function spec for ``query``."""
        # Pair the query with every candidate description for scoring.
        pairs = []
        for spec in functions:
            pairs.append((query, spec["description"]))
        relevance = cross_encoder.predict(pairs)
        best_index = relevance.argmax()
        return functions[best_index]

    return rank

# Build the ranking callable once; the cross-encoder is created inside
# rank_pipeline and reused by every later call to ranker(...).
ranker = rank_pipeline(ranker_id)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [25]:
import json
def generate_functions_prompt(query, functions=None):
    """Build the USER/ASSISTANT prompt for a function-calling query.

    Args:
        query: The user's natural-language request.
        functions: Optional list of JSON-serializable function specs. When
            omitted or empty, the ``<<function>>`` section is left out
            rather than serializing ``None`` as the literal text ``null``.

    Returns:
        The prompt string, ending in ``"ASSISTANT: "``.
    """
    if not functions:
        return f"USER: <<question>> {query}\nASSISTANT: "
    return f"USER: <<question>> {query} <<function>> {json.dumps(functions)}\nASSISTANT: "

def generate(prompt, functions=None):
    """Generate a completion for a user query.

    Args:
        prompt: The raw user query (not an already-formatted prompt).
        functions: Optional list of function specs to embed in the prompt.

    Returns:
        The ``'generated_text'`` of the first pipeline result.
    """
    # Format the prompt here, and only here, so callers pass the raw query.
    full_prompt = generate_functions_prompt(prompt, functions)
    outputs = pipe(
        full_prompt,
        max_new_tokens=512,
        do_sample=True,
        return_full_text=False,
    )
    return outputs[0]['generated_text']


def generate_ranked(prompt, functions):
    """Pick the best-matching function for the query and generate with it.

    Args:
        prompt: The raw user query.
        functions: Candidate function specs; each needs ``"name"`` and
            ``"description"`` keys.

    Returns:
        The generated text for the query with only the selected function
        offered to the model.
    """
    fn = ranker(prompt, functions)
    print(f"using: {fn['name']}")
    # Pass the raw query plus the chosen spec; formatting the prompt here and
    # again inside generate() would nest one prompt template inside another.
    return generate(prompt, [fn])

In [41]:
# Function specifications offered to the model, written as JSON-Schema-style
# dicts. "required" sits beside "properties" (not inside it) so that it lists
# the mandatory arguments instead of reading as an argument named "required".
functions = [
    {
        "name": "call_uber",
        "description": "Find suitable ride for customers given the location, type of ride, and the amount of time the customer is willing to wait as parameters",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "location of the starting place of the uber ride",
                },
                "ride_type": {
                    "type": "string",
                    "enum": ["plus", "comfort", "black"],
                    "description": "types of uber ride user is ordering",
                },
                "time": {
                    "type": "number",
                    "description": "the amount of time in minutes the customer is willing to wait",
                },
            },
            "required": ["location", "time"],
        },
    },
    {
        "name": "get_current_weather",
        "description": "Gets the current weather for a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
            },
            "required": ["location"],
        },
    },
]

In [32]:
# Show which function spec the ranker selects for a weather question.
ranker("What is the weather in New York?",functions)

{'name': 'get_current_weather',
 'description': 'Gets the current weather for a given location',
 'parameters': {'type': 'object',
  'properties': {'location': {'type': 'string',
    'description': 'The city and state, e.g. San Francisco, CA'},
   'format': {'type': 'string',
    'enum': ['celsius', 'fahrenheit'],
    'description': 'The temperature unit to use. Infer this from the users location.'}},
  'required': ['location']}}

In [39]:
# Rank the candidate functions for a weather question, then print the
# generated text with surrounding whitespace stripped.
res = generate_ranked("What is the weather in New York?",functions=functions)
print(res.strip())


using: get_current_weather
get_current_weather(location="New York", format="celsius")


In [42]:
# Same flow for a ride request: rank the candidates, generate, print.
res = generate_ranked("Call me an Uber ride in Berkeley at zipcode 94704 in 10 minutes", functions=functions)
print(res.strip())

using: call_uber
call_uber(location="94704", ride_type="plus", time=10)
