In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Union
from tqdm.auto import tqdm, trange

MODEL = "teknium/OpenHermes-2.5-Mistral-7B"

In [2]:
import json
import xml.etree.ElementTree as ET
import re


class OpenHermesInference:

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __inference__(self, messages: List[Dict]):
        tokens = self.tokenizer.apply_chat_template(messages,
                                                    return_tensors="pt").to(
                                                        self.model.device)
        input_size = tokens.numel()
        print("Input Tokens: ", input_size)
        with torch.inference_mode():
            generated_tokens = self.model.generate(
                tokens,
                use_cache=True,
                do_sample=True,
                temperature=0.2,
                top_p=1.0,
                top_k=0,
                max_new_tokens=512,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        print("Generated New Tokens: ",
              len(generated_tokens.squeeze()[input_size:]))
        return self.tokenizer.decode(generated_tokens.squeeze()[input_size:],
                                     skip_special_tokens=True)


class FunctionCall(OpenHermesInference):

    def __init__(self, model, tokenizer):
        super().__init__(self, FunctionCall)
        self.system_prompt = """You are a helpful assistant with access to the following functions:
        
            {functions}
        
            To use these functions respond with:
            <multiplefunctions>
                <functioncall> {{fn}} </functioncall>
                <functioncall> {{fn}} </functioncall>
                ...
            </multiplefunctions>
            
            Edge cases you must handle:
            - If there are no functions that match the user request, you will respond politely that you cannot help.<|im_end|>

            Refer the below provided output example for function calling
            Question: What's the weather difference in NY and LA?
            <multiplefunctions>
                <functioncall> {{"name": "getWeather", "parameters": {{"city": "NY"}}}} </functioncall>
                <functioncall> {{"name": "getWeather", "parameters": {{"city": "LA"}}}} </functioncall>
            </multiplefunctions>
            
        """

    def functionCall(self, messages: List[Dict], functions: List[Dict]):
        functions_texts = "\n\n".join(
            [f"{json.dumps(function)}" for function in functions])
        if messages[0].get("role") == "system":
            new_system_prompt = (
                self.system_prompt.format(functions=functions_texts) + "\n" +
                messages[0].get("content"))
            messages[0]["content"] = new_system_prompt
        else:
            messages = [{
                "role":
                "system",
                "content":
                self.system_prompt.format(functions=functions_texts),
            }] + messages
        output_text = self.__inference__(messages)
        return output_text


class NormalCompletion(OpenHermesInference):

    def __init__(self, model, tokenizer):
        super().__init__(self, NormalCompletion)

    def normalCompletion(self, messages: List[str]):
        output_text = self.__inference__(messages)
        return output_text


class FunctionExtraction:

    def __call__(self, text: str):
        completion = text.strip()
        pattern = r"(<multiplefunctions>(.*?)</multiplefunctions>)"
        match = re.search(pattern, completion, re.DOTALL)
        if not match:
            return None
        multiplefn = match.group(1)
        root = ET.fromstring(multiplefn)
        functions = root.findall("functioncall")
        return [json.loads(fn.text) for fn in functions]


class Completion(FunctionCall, NormalCompletion, FunctionExtraction):

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        super().__init__(self, Completion)
        self.model = model
        self.tokenizer = tokenizer

    def chatCompletion(self,
                       messages: List[Dict],
                       functions: Union[None, List] = None):
        if functions:
            function_call_text = self.functionCall(messages, functions)
            functions = FunctionExtraction()(function_call_text)
            return functions
        else:
            return self.normalCompletion(messages)

In [5]:
completion_tokenizer = AutoTokenizer.from_pretrained(MODEL)
completion_model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
completion_obj = Completion(completion_model, completion_tokenizer)

In [7]:
from pydantic import BaseModel, Field
from typing import List, Dict


class QueryDependency(BaseModel):
    id: int = Field(..., description="Unique Integer Id for the Query")
    question: str = Field(
        ...,
        description=
        "Question we want to ask to get a better context or more background about the main question.",
    )


class Dependencies(BaseModel):
    dependencies: List[QueryDependency] = Field(
        ...,
        description=
        "A list of query dependencies in the correct sequence to fetch more background information about the main question.",
    )


functions = [{
    "name": "dependencyPlanning",
    "description":
    "Plan a sequential list of all the sub-questions that once answered can provide more background to answer the main question.",
    "parameters": Dependencies.schema(),
}]

/tmp/ipykernel_484/580731889.py:26: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
  "parameters": Dependencies.schema()


In [9]:
DEPENDENCY_PROMPT = """You're a ChatGPT powered query planning agent. Given a user message provide all the question or context dependencies that would need to be addressed to provide a response to the user.
You've to break down questions into its dependent queries such that the answers of the dependent query can be used to inform the parent question.
You don't need to answer the questions, simply provide the correct sequence of questions to ask and relevant dependencies.
Call the function with appropriate data i.e. the dependencies.
"""

In [10]:
question = (
    "what's the distance between the capital of France and capital of United Kingdom?"
)
completion_messages = [
    {
        "role":
        "user",
        "content":
        f"""{DEPENDENCY_PROMPT}

        Question: {question}
        """,
    },
    {
        "role": "assistant",
        "content": ""
    },
]
function_calls = completion_obj.chatCompletion(completion_messages, functions)
function_calls

Input Tokens:  574




Generated New Tokens:  85


[{'name': 'dependencyPlanning',
  'parameters': {'dependencies': [{'id': 1,
     'question': 'What is the capital of France?'},
    {'id': 2, 'question': 'What is the capital of the United Kingdom?'}]}}]

In [11]:
question = "provide comparison between GPT-4 and Mistral-7B models benchmarks"
completion_messages = [
    {
        "role":
        "user",
        "content":
        f"""{DEPENDENCY_PROMPT}

        Question: {question}
        """,
    },
    {
        "role": "assistant",
        "content": ""
    },
]
function_calls = completion_obj.chatCompletion(completion_messages, functions)
function_calls

Input Tokens:  574
Generated New Tokens:  131


[{'name': 'dependencyPlanning',
  'parameters': {'dependencies': [{'id': 1,
     'question': 'What are the key features and capabilities of GPT-4?'},
    {'id': 2,
     'question': 'What are the key features and capabilities of Mistral-7B?'},
    {'id': 3,
     'question': 'How do the benchmarks of GPT-4 and Mistral-7B compare in terms of performance and accuracy?'}]}}]

In [12]:
question = "compare iPhone 14 pro with iPhone 15 pro"
completion_messages = [
    {
        "role":
        "user",
        "content":
        f"""{DEPENDENCY_PROMPT}

        Question: {question}
        """,
    },
    {
        "role": "assistant",
        "content": ""
    },
]
function_calls = completion_obj.chatCompletion(completion_messages, functions)
function_calls

Input Tokens:  570
Generated New Tokens:  93


[{'name': 'dependencyPlanning',
  'parameters': {'dependencies': [{'id': 1,
     'question': 'What are the key features of iPhone 14 Pro?'},
    {'id': 2, 'question': 'What are the key features of iPhone 15 Pro?'}]}}]