<a href="https://colab.research.google.com/github/vieenrose/Conformer-Transducer/blob/master/Bleat_Llama_2_function_calling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bleat 🦙😯
## LLaMA 2 Function Calling

Bleat allows you to enable function calling in LLaMA 2 in a similar fashion to OpenAI's implementation for ChatGPT.

In [None]:
# @title Install requirements
!nvidia-smi
!pip install transformers peft colorama bitsandbytes

In [None]:
# @title Load models and code (this takes a while)
import torch
import json
from transformers import AutoTokenizer, StoppingCriteria, GenerationConfig
from peft import AutoPeftModelForCausalLM
import colorama

# Load the "IfanSnek/bleat-adapter" checkpoint through PEFT in float16, with
# automatic device placement and the current directory as the offload folder.
base_model = AutoPeftModelForCausalLM.from_pretrained(
    "IfanSnek/bleat-adapter",
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="./"
)
# The tokenizer is loaded from the Llama-2 chat checkpoint, not from the adapter repo.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf", trust_remote_code=True)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


class Streamer:
    """Optional token printer exposing `put` and `end`; silent unless enabled."""

    def __init__(self, enabled):
        # Flag consulted by `put` on every call.
        self.do_stream = enabled

    def put(self, input_ids):
        """Print the decoded first row of `input_ids` when streaming is enabled."""
        if not self.do_stream:
            return
        print(tokenizer.decode(input_ids[0]))

    def end(self):
        """Do nothing; there is no state to finalize."""
        return None


class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` that stops once an end-of-function string appears in the new text.

    Only the first sequence of the batch (`input_ids[0]`) is inspected.
    """

    def __init__(self, start_length, eof_strings):
        # Uses the module-level tokenizer to turn token ids back into text.
        self.tokenizer = tokenizer
        self.eof_strings = eof_strings
        # Character offset into the decoded text from which the new generation is read.
        self.start_length = start_length

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if any end-of-function string occurs in the newly generated text."""
        full_text = self.tokenizer.decode(input_ids[0])
        tail = full_text[self.start_length:]
        for marker in self.eof_strings:
            # The first len(marker) characters of the tail are skipped before
            # searching, so a marker sitting at the very start is not counted.
            if marker in tail[len(marker):]:
                return True
        return False


# Start from the generation defaults derived from the model config, then set length limits.
# NOTE(review): both max_length and max_new_tokens are set here — confirm which
# one the installed transformers version applies when they disagree.
config = GenerationConfig.from_model_config(base_model.config)
config.max_length = 4096
config.max_new_tokens = 128


def infer(text):
    """
    Infer a function call from a prompt
    :param text: The prompt to infer from. Should not end with a header (eg, "### System: Hi!\n" but not "### System:")
    :return: The new text (with its leading and trailing section header removed), and the header it
             started with ("" when the text starts with none of the known headers)
    """

    headers = ["### System:", "### User:", "### Call:", "### Return:", "### Assistant:"]
    ids = base_model.generate(
        input_ids=tokenizer.encode(text, return_tensors="pt").to(base_model.device),
        # Stop as soon as the model starts writing the next section header
        stopping_criteria=[EndOfFunctionCriteria(len(text), headers)],
        generation_config=config,
    )
    decoded_generation = tokenizer.decode(ids[0])

    # Keep only the newly generated text
    new_generation = decoded_generation.replace(text, "")

    # Drop the special-token markers. The Llama end-of-sequence marker is "</s>";
    # the previously handled "<\s>" spelling is still stripped for compatibility.
    for marker in ("<s>", "</s>", "<\\s>"):
        new_generation = new_generation.replace(marker, "")
    new_generation = new_generation.strip("\n").strip()

    # Find the header the response starts with
    start_header = next((header for header in headers if new_generation.startswith(header)), "")
    if start_header:
        new_generation = new_generation[len(start_header):]

    # Remove the header the response ends with (the one that triggered the stop)
    for header in headers:
        if new_generation.endswith(header):
            new_generation = new_generation[:-len(header)]

    return new_generation, start_header


# Registry of (json_definition, fn) pairs filled in by register_function().
functions = []
# The registry as query() last saw it when it wrote a system block.
previous_functions = []


def register_function(json_definition, fn):
    """
    Add a function to the module-level `functions` registry
    :param json_definition: The JSON-serializable description of the function
    :param fn: The Python callable stored alongside the description
    """
    entry = (json_definition, fn)
    functions.append(entry)


# Conversation text accumulated across query() calls and passed to infer().
running_chat = ""


def query(prompt):
    """
    Main function for querying the assistant
    :param prompt: The user message to send, or the literal "reset" to clear the chat history
    :return: The assistant's reply with surrounding whitespace stripped; "Resetting chat" after a
             reset; or a "### Assistant: ..." error line when a function call could not be completed
    """
    global running_chat
    global functions
    global previous_functions

    if prompt == "reset":
        running_chat = ""
        previous_functions = []
        return "Resetting chat"

    # If functions changed, create a new system text
    if functions != previous_functions:
        system_text = "### System:\n" + json.dumps([f[0] for f in functions], indent=4) + "\n"
        running_chat += system_text
        # Keep a copy, not an alias: with an alias, functions registered later
        # would change both names at once and never be detected as a change.
        previous_functions = list(functions)

    # Record the user turn in the history so that the follow-up inference after
    # a function call, and every later query, still sees what the user asked.
    running_chat += "### User: " + prompt + "\n"

    # Generate the next response
    output, role = infer(running_chat)

    # If the response is a function call, try to call it
    if role == "### Call:":
        try:
            call = json.loads(output)
            # Valid JSON that is not an object (a number, a list, ...) is malformed too
            if not isinstance(call, dict) or "name" not in call or "parameters" not in call:
                raise ValueError
        except ValueError:
            print(colorama.Fore.RED + "Error parsing call: " + output + colorama.Style.RESET_ALL)
            response = "### Assistant: The call I wrote was formatted wrong. Please try again.\n"
            running_chat += response
            return response

        # Check if the function exists
        if call["name"] not in [f[0]["name"] for f in functions]:
            print(colorama.Fore.RED + "Error: function " + str(call["name"]) + " does not exist" + colorama.Style.RESET_ALL)
            response = "### Assistant: The function I wrote does not exist. Please try again.\n"
            running_chat += response
            return response

        # Try to call the function
        try:
            fn = [f[1] for f in functions if f[0]["name"] == call["name"]][0]
            # Parameter values are passed positionally, in the order they appear in the call
            params = list(call["parameters"].values())

            fn_output = str(fn(*params))
            params_str = ", ".join([str(p) for p in params])
            print(f"{colorama.Fore.GREEN}{call['name']}({params_str}) = {fn_output}{colorama.Style.RESET_ALL}")

            fn_output = json.dumps({"result": fn_output})

        except Exception as e:
            # Boundary around user-supplied code: report the failure instead of crashing the chat
            print(colorama.Fore.RED + "Error calling function: " + str(e) + "\n" + str(call) + colorama.Style.RESET_ALL)
            response = "### Assistant: The function tried to call had an error. Please try again.\n"
            running_chat += response
            return response

        # Feed the return value back into the model
        running_chat += "### Call:\n" + json.dumps(call) + "\n"
        running_chat += "### Return:\n" + fn_output + "\n"
        output, role = infer(running_chat)

    # Add the response to the running chat
    running_chat += "### Assistant:" + output + "\n"

    return output.strip()

To start, define a function in the cell below:

In [None]:
def sunset(city):
    """Example function for registration; returns the same fixed time for every `city`."""
    sunset_time = "6:00 PM"
    return sunset_time

Once you have run the cell, create a definition in JSON like so:


```
definition = {
  "name": "function_name",
  "description": "Function Description",
  "parameters": [
      {
          "name": "parameter_name",
          "type": "python type (str, int, etc.)",
          "description": "Parameter description"
      },
  ],
  "required": ["list", "of", "required", "parameters"]
}
```



In [None]:
# JSON description of the function; query() matches a generated call against
# the "name" field and serializes the whole dict into the system text.
definition = {
        "name": "get_sunset_time",
        "description": "Get the time of sunset in a city",
        "parameters": [
            {
                "name": "city",
                "type": "string"
            },
        ],
        "required": ["city"]
    }

Now, register the function.

In [None]:
# Clear old functions if applicable
functions = []
previous_functions = []

register_function(definition, sunset)

Send a message to LLaMA:

In [None]:
query("reset") # Forgets everything
# Ask about a city so the model has a reason to call the registered function.
print(query("When does the sun set in Boston?"))

[32mget_sunset_time(Denver) = 6:00 PM[0m


Have a whole heckin' conversation:

In [None]:
# Start from an empty chat history.
query("reset")
# Read a line, send it as a user message, print the reply; the loop has no exit
# condition, so it runs until the cell is interrupted.
while True:
  print(query(input("> ")))