## **All imports go here**

In [None]:
import os
import time
import json
import re
import csv
import ast
from copy import deepcopy
from random import randrange
from functools import partial
import torch
import accelerate
import bitsandbytes as bnb        #GPU Required bitsandbytes/restart the session if error
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import (
    get_peft_model,
    PeftModel
)

# Target device for tokenized inputs: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## **Load the model**

In [None]:
# Local checkpoint directories (relative to the notebook's working directory):
#   mdl      - path handed to the tokenizer and the base causal-LM loader
#   hf_repo1 - first adapter checkpoint passed to PeftModel.from_pretrained
#   hf_repo2 - second adapter checkpoint, loaded on top of the first
mdl = "../RTaC-Models/codellama/CodeLlama-7b-Instruct-hf"
hf_repo1 = "../RTaC-Models/codellama/codellama-7b-final-stage-1"
hf_repo2 = "../RTaC-Models/codellama/codellama-7b-final-stage-2-if"

In [None]:
# Tokenizer comes from the base checkpoint; the EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(mdl)
tokenizer.pad_token = tokenizer.eos_token
# 4-bit NF4 quantization with double quantization and bfloat16 compute.
bnb_config = BitsAndBytesConfig(                      #modify configuration if higher GPU RAM
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Quantized base model; device placement is delegated to device_map="auto".
inf_model = AutoModelForCausalLM.from_pretrained(
    mdl,
    quantization_config=bnb_config,
    device_map="auto",
)
# The two adapter checkpoints are applied in sequence: the second PeftModel wraps the first.
# NOTE(review): wrapping a PeftModel in another PeftModel is unusual - confirm both adapters
# are actually active at inference time (vs. merging stage 1 before loading stage 2).
load_model1 = PeftModel.from_pretrained(inf_model, hf_repo1 , device_map='auto')
load_model = PeftModel.from_pretrained(load_model1, hf_repo2 , device_map='auto')
# `model` is the name the generation helper below reads.
model = load_model

In [None]:
def get_output(prompt):
  """Generate text for ``prompt`` with the module-level ``model``.

  The prompt is tokenized with the module-level ``tokenizer``, moved to
  ``device``, and passed to ``model.generate`` with a budget of 200 new
  tokens. Gradients are disabled and the model is put in eval mode first.

  Returns the first generated sequence decoded to a string with special
  tokens removed.
  """
  encoded = tokenizer(prompt, return_tensors="pt").to(device)
  model.eval()
  with torch.no_grad():
    generated = model.generate(**encoded, max_new_tokens=200)
  return tokenizer.decode(generated[0], skip_special_tokens=True)

## **Adding/Modifying tools**

In [None]:
def add_new_tool(tools_list):
  """Interactively add one tool to ``tools_list`` (or extend an existing one).

  Everything is read from standard input via ``input()``:
  tool name, tool description, then repeated argument entries
  (name, description, type, allowed values) until the argument name ``E``
  is entered, and finally the return type and return description.

  If ``tool_name`` is already present, its existing entry is kept, the new
  arguments are appended to it, and its return type/description are
  overwritten; the description typed in is only echoed, not stored.

  Parameters:
    tools_list (dict): tool registry keyed by tool name; mutated in place.

  Returns:
    dict: the same ``tools_list`` object.
  """
  print("Add tool_name:")
  tool_name = input()
  print("Add tool Description:")
  Description = input()
  print("Tool name:" , tool_name)
  print("tool Description:" , Description)
  print("Add arguments and their description , press E as Arg Name to Escape")
  if tool_name not in tools_list:
      tools_list[tool_name] = { 'Description': Description , 'Arguments':[],'ReturnType':'_' , 'ReturnDescription':'_'}
  while True:
    print("Input Argumet Name")
    arg_name = input()
    if arg_name == 'E':
      break
    print("Arg Name:" , arg_name)
    print("Input Argument Description Here")
    arg_desc = input()
    print("Arg Description:" , arg_desc)
    print("Input Argument type")
    arg_type = input()
    print("Argument_type:", arg_type)
    print('Input Allowed Values:')
    arg_example = input()
    print('Allowed values:', arg_example)
    # Build the entry with exactly the four keys the rest of the file reads
    # (the old template also left behind an unused, always-empty
    # "ArgumentValue Example" key that the built-in tool entries do not have).
    tools_list[tool_name]['Arguments'].append({
        "ArgumentName": arg_name,
        "Argument Description": arg_desc,
        "ArgumentType": arg_type,
        "AllowedValues": arg_example,
    })

  print("Add return type of the function")
  re_type = input()
  tools_list[tool_name]['ReturnType'] = re_type
  print("Add return description of the function")
  re_desc = input()
  tools_list[tool_name]['ReturnDescription'] = re_desc

  return tools_list

In [None]:
def add_multiple_tools(tools_list):
  """Prompt for a count, then run ``add_new_tool`` that many times.

  The count is read from standard input and converted with ``int``.

  Parameters:
    tools_list (dict): tool registry handed to ``add_new_tool``.

  Returns:
    dict: the registry returned by the last ``add_new_tool`` call, or the
    input registry unchanged when the count is zero or negative.
  """
  print("Input the number of tools you want to add")
  remaining = int(input())
  while remaining > 0:
    tools_list = add_new_tool(tools_list)
    remaining -= 1

  return tools_list


In [None]:
def delete_tool():
  """Ask the user which tools to delete and return their names.

  Reads a count from standard input, then that many tool names.
  Nothing is deleted here; the caller is expected to act on the result.

  Returns:
    list: the entered tool names, in the order they were typed.
  """
  print("Number of tools that you want to Delete:")
  wanted = int(input())
  doomed = []
  while len(doomed) < wanted:
     print("Tool_Name:")
     doomed.append(input())

  return doomed

## **Tool list**

In [None]:
# Registry of the tools the model may call, keyed by tool name.
# Each entry holds:
#   "Description"       - free-text description rendered into the prompt
#   "Arguments"         - list of argument specs, each with "ArgumentName",
#                         "Argument Description", "ArgumentType" and "AllowedValues"
#                         ('anything' means unrestricted; otherwise a list of accepted values)
#   'ReturnType' / 'ReturnDescription' - rendered into the "Returns:" part of the prompt
tools_list= {"works_list": {
"Description": "Returns a list of work items matching the request",
"Arguments": [
{
"ArgumentName": "applies_to_part",
"Argument Description": "Filters for work belonging to any of the provided parts",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "created_by",
"Argument Description": "Filters for work created by any of these users",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "issue.priority",
"Argument Description": "Filters for issues with any of the provided priorities. Allowed values: p0, p1, p2, p3",
"ArgumentType": "list",
"AllowedValues": ['p0' , 'p1' , 'p2' ,'p3']
},
{
"ArgumentName": "issue.rev_orgs",
"Argument Description": "Filters for issues with any of the provided Rev organizations",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "limit",
"Argument Description": "The maximum number of works to return. The default is '50'",
"ArgumentType": "int",
"AllowedValues": 'anything'
},
{
"ArgumentName": "owned_by",
"Argument Description": "Filters for work owned by any of these users",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "stage.name",
"Argument Description": "Filters for records in the provided stage(s) by name",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "ticket.needs_response",
"Argument Description": "Filters for tickets that need a response",
"ArgumentType": "bool",
"AllowedValues": ['True' , 'False']
},
{
"ArgumentName": "ticket.rev_org",
"Argument Description": "Filters for tickets associated with any of the provided Rev organizations",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "ticket.severity",
"Argument Description": "Filters for tickets with any of the provided severities. Allowed values: blocker, high, low, medium",
"ArgumentType": "list",
# NOTE(review): the description above says "medium" but the accepted list says 'mid';
# a model answer of "medium" would be rejected by the allowed-values check - confirm which is intended.
"AllowedValues": ['blocker' , 'high' , 'mid' , 'low']
},
{
"ArgumentName": "ticket.source_channel",
"Argument Description": "Filters for tickets with any of the provided source channels",
"ArgumentType": "list",
"AllowedValues": 'anything'
},
{
"ArgumentName": "type",
"Argument Description": "Filters for work of the provided types. Allowed values: issue, ticket, task",
"ArgumentType": "list" ,
"AllowedValues": ['issue' , 'ticket' , 'task']
}
# NOTE(review): the return type string below repeats "list:" - looks like a paste slip; confirm before changing,
# since this text is part of the prompt.
],'ReturnType':'list: list: Matching work items',
  'ReturnDescription':'Returns a list of work items matching the request'
},
"summarize_objects": {
"Description": "Summarizes a list of objects. The logic of how to summarize a particular object type is an internal implementation detail",
"Arguments": [
{
"ArgumentName": "objects",
"Argument Description": "List of objects to summarize",
"ArgumentType": "list" ,
"AllowedValues": 'anything'
}
],
  'ReturnType':'text',
  'ReturnDescription':'summarized text of the objects'

},
"prioritize_objects": {
"Description": "Returns a list of objects sorted by priority. The logic of what constitutes priority for a given object is an internal implementation detail",
"Arguments": [
{
"ArgumentName": "objects",
"Argument Description": "A list of objects to be prioritized",
"ArgumentType": "array of objects" ,
"AllowedValues": 'anything'
}
],
  'ReturnType':'List',
  'ReturnDescription':'Prioritized objects'

},
"add_work_items_to_sprint": {
"Description": "Adds the given work items to the sprint",
"Arguments": [
{
"ArgumentName": "work_ids",
"Argument Description": "A list of work item IDs to be added to the sprint",
"ArgumentType": "array of strings" ,
"AllowedValues": 'anything'
},
{
"ArgumentName": "sprint_id",
"Argument Description": "The ID of the sprint to which the work items should be added",
"ArgumentType": "string" ,
"AllowedValues": 'anything'
}
],
  'ReturnType': '',
  'ReturnDescription':''
},
"get_sprint_id": {
"Description": "Returns the ID of the current sprint" ,
"Arguments":[],
'ReturnType': '',
'ReturnDescription':''


},
"get_similar_work_items": {
"Description": "Returns a list of work items that are similar to the given work item",
"Arguments": [
{
"ArgumentName": "work_id",
"Argument Description": "The ID of the work item for which you want to find similar items",
"ArgumentType": "string",
"AllowedValues": 'anything'
}

],
                              'ReturnType': 'list',
  'ReturnDescription':'list of similar work items'
},
"search_object_by_name": {
"Description": "Given a search string, returns the ID of a matching object in the system of record. If multiple matches are found, it returns the one where the confidence is highest",
"Arguments": [
{
"ArgumentName": "query",
# NOTE(review): "customerâ€™s" below looks like a mis-decoded apostrophe (UTF-8 read as cp1252).
# Left untouched because the text is sent verbatim in the prompt - confirm before fixing.
"Argument Description": "The search string, could be for example customerâ€™s name, part name, user name",
"ArgumentType": "string",
"AllowedValues": 'anything'
}
]
                           , 'ReturnType' : 'str',
'ReturnDescription':'ID of matching object'
},
"create_actionable_tasks_from_text": {
"Description": "Given a text, extracts actionable insights, and creates tasks for them, which are kind of a work item",
"Arguments": [
{
"ArgumentName": "text",
"Argument Description": "The text from which the actionable insights need to be created",
"ArgumentType": "string",
"AllowedValues": 'anything'
}
], 'ReturnType' : 'list',
'ReturnDescription':'tasks created from the given text'

},
"who_am_i": {
"Description": "Returns the ID of the current user",
"Arguments" : [],
              'ReturnType' : '',
'ReturnDescription':''

}
}

## **Parsers**

In [None]:
def json_to_python(json):
    """Render a tool registry as Python-style function stubs with docstrings.

    For every tool a ``def name(arg1, arg2):`` header is emitted, followed by
    a triple-quoted block holding the description, a ``Parameters:`` section
    (one line per argument, with the allowed values appended unless they are
    the string ``"anything"``) and a ``Returns:`` section.

    Parameters:
        json (dict): registry mapping tool name to its spec
            (``Description``, ``Arguments``, ``ReturnType``,
            ``ReturnDescription``). The parameter name shadows the ``json``
            module inside this function only.

    Returns:
        str: all stubs concatenated, each preceded by two newlines;
        the empty string for an empty registry.
    """
    stubs = []
    for tool in json:
        spec = json[tool]
        description = spec["Description"]
        return_type = spec["ReturnType"]
        return_desc = spec["ReturnDescription"]

        signature_names = []
        parameter_lines = []
        for argument in spec["Arguments"]:
            name = argument["ArgumentName"]
            kind = argument["ArgumentType"]
            about = argument["Argument Description"]
            allowed = argument["AllowedValues"]

            signature_names.append(f"{name}")
            line = f"  {name}({kind}): {about}"
            if allowed != "anything":
                line += f". Allowed Values: {allowed}"
            parameter_lines.append(line + "\n")

        stub = "".join([
            f"def {tool}({', '.join(signature_names)})",
            ':\n"""\n',
            f"{description}\n\nParameters:\n",
            "".join(parameter_lines),
            "\nReturns:\n",
            f"  ({return_type}): {return_desc}\n",
            '"""\n',
        ])
        stubs.append("\n\n" + stub)
    return "".join(stubs)

In [None]:
def modify_args(args):
    """Trim ``args`` at the parenthesis that closes the enclosing call.

    ``args`` is the text that follows an already-consumed opening ``(``.
    Nested parentheses are tracked; everything before the first ``)`` that
    brings the nesting back to zero is returned. If no such parenthesis
    exists the whole string is returned unchanged.
    """
    depth = 1  # the caller has already consumed one opening parenthesis
    for position, char in enumerate(args):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return args[:position]
    return args

def get_avl_tools():
    """Build a lookup of allowed argument values from the global ``tools_list``.

    Returns:
        dict: ``{tool_name: {argument_name: allowed_values}}`` where
        ``allowed_values`` is whatever the registry stores under
        ``"AllowedValues"`` for that argument.
    """
    return {
        tool: {
            spec["ArgumentName"]: spec["AllowedValues"]
            for spec in tools_list[tool]["Arguments"]
        }
        for tool in tools_list
    }

def edit_distance(str1, str2):
    """
    Compute the edit (Levenshtein) distance between two strings.

    A handful of ordered (str1, str2) pairs that correspond to common model
    mistakes are short-circuited to a distance of 0. The shortcut is
    direction-sensitive: only the exact (model output, real name) order hits.
    """
    known_slips = {
        ("whoami", "who_am_i"),
        ("get_current_sprint_id", "get_sprint_id"),
        ("create_actions_from_text", "create_actionable_tasks_from_text"),
        ("work_type", "type"),
    }
    if (str1, str2) in known_slips:
        return 0

    # Iterate over the longer string row by row, keeping only two rows of the table.
    longer, shorter = (str2, str1) if len(str1) < len(str2) else (str1, str2)

    above = list(range(len(shorter) + 1))
    for row, long_char in enumerate(longer, start=1):
        here = [row]
        for col, short_char in enumerate(shorter, start=1):
            substitute = above[col - 1] + (0 if long_char == short_char else 1)
            here.append(min(here[col - 1] + 1, above[col] + 1, substitute))
        above = here

    return above[-1]

def general_update(name,nameslist):
    """
    Pick the entry of ``nameslist`` closest to ``name`` by ``edit_distance``.

    The first entry with the smallest distance wins, but only if that distance
    is strictly smaller than ``len(name)``; otherwise ``name`` itself is the
    candidate. The candidate is returned when its distance is at most half of
    ``len(name)``; in every other case the result is None.
    """
    ceiling = len(name)
    scored = [(edit_distance(name, candidate), candidate) for candidate in nameslist]
    # min() keeps the first of several equally good candidates.
    best_distance, best_name = min(scored, key=lambda pair: pair[0], default=(ceiling, name))
    if best_distance >= ceiling:
        best_distance, best_name = ceiling, name

    return best_name if 2 * best_distance <= len(name) else None

def update_tool(tool_name):
    """
    Resolve ``tool_name`` to the closest registered tool name.

    Delegates to ``general_update`` with the tool names from
    ``get_avl_tools``; the result is None when nothing is close enough.
    """
    registered_names = get_avl_tools().keys()
    return general_update(tool_name, registered_names)

def update_arg_name(arg_name,tool_name):
    """
    Resolve ``arg_name`` to the closest argument name of ``tool_name``.

    ``tool_name`` must be a key of ``get_avl_tools()``. Delegates to
    ``general_update``; the result is None when nothing is close enough.
    """
    argument_names = get_avl_tools()[tool_name].keys()
    return general_update(arg_name, argument_names)

def update_arg_val(arg_value,arg_name,tool_name,arg_index,tools,start,temp_index=None):
    """
    Normalize one argument value for ``tool_name``/``arg_name``.

    Behaviour by value shape:
      * empty or whitespace-only -> None
      * ``[a, b, ...]`` (closing bracket optional) -> list with every element
        normalized recursively
      * value already starting with ``$$`` -> returned unchanged
      * ``name(args)`` -> the nested call is processed via ``process_tool``
        (appended to ``tools``) and the value becomes ``$$PREV[<its index>]``
      * anything else -> returned as is when the argument accepts 'anything'
        or the value is among the allowed values, otherwise ``"$$INV_ARG"``

    Parameters:
        arg_value (str): raw value text.
        arg_name (str): resolved argument name (key of the tool's allowed-values map).
        tool_name (str): resolved tool name (key of ``get_avl_tools()``).
        arg_index (dict): ``var_N`` number -> index in the top-level tool list.
        tools (list): tool list that nested calls are appended to.
        start (str): ``"var_"`` or ``"temp_"``; selects which index map a nested call uses.
        temp_index (dict | None): ``temp_N`` number -> index, used when ``start == "temp_"``.

    Returns:
        str | list | None: the normalized value as described above.
    """
    # Strip first so a whitespace-only value is treated as empty instead of
    # failing on arg_value[0] below.
    arg_value = arg_value.strip()
    if not arg_value:
        return None

    if arg_value[0] == '[':
        if arg_value[-1] != ']':
            arg_value += ']'
        elements = arg_value[1:-1].strip("\"").strip("\'").split(",")
        return [
            update_arg_val(
                element.strip().strip("\"").strip("\'"),
                arg_name, tool_name, arg_index, tools, start, temp_index,
            )
            for element in elements
        ]

    if arg_value.startswith("$$"):
        return arg_value

    if '(' in arg_value:
        match = re.match(r"\s*(\w+)\((.*)\)",arg_value)
        # A literal that merely contains parentheses (e.g. "Login (SSO)") does
        # not match; it is then validated like any other plain value instead of
        # raising on match.group().
        if match:
            # process_tool() records the nested call under slot 0 of the active
            # index map. Remember what was there so a real var_0/temp_0 entry
            # is not lost once the nested call's index has been read.
            slots = temp_index if start == "temp_" else arg_index
            had_slot_zero = 0 in slots
            previous_zero = slots.get(0)

            process_tool(0,match.group(1),match.group(2),tools,arg_index,start,temp_index)

            if start in ("temp_", "var_"):
                arg_value = f"$$PREV[{slots[0]}]"

            if had_slot_zero:
                slots[0] = previous_zero
            else:
                slots.pop(0, None)

    allowed = get_avl_tools()[tool_name][arg_name]
    if allowed == 'anything' or arg_value in allowed:
        return arg_value

    return "$$INV_ARG"

def wrong_name_handler(tool_name,args,arg_index,start,temp_index=None):
    """
    Build a tool dictionary for a tool name that could not be resolved.

    Because the tool is unknown, argument names and values are not validated:
    every ``name=value`` piece produced by ``arg_splitter`` is kept, with
    quotes removed and bracketed values split on commas into a list.
    References to earlier results are still rewritten:
      * ``start == "var_"``:  ``var_N``  -> ``$$PREV[i]``
      * ``start == "temp_"``: ``temp_N`` -> ``$$PREV[i]`` and ``var_N`` -> ``$$GLOB_PREV[i]``

    Parameters:
        tool_name (str): the unresolved tool name, used as is.
        args (str): argument text of the call.
        arg_index (dict): ``var_N`` number -> index in the top-level tool list.
        start (str): ``"var_"`` or ``"temp_"``.
        temp_index (dict | None): ``temp_N`` number -> index inside the current block.

    Returns:
        dict: ``{"tool_name": ..., "arguments": [{"argument_name", "argument_value"}, ...]}``
    """
    def _rewrite_refs(text, prefix, mapping, label):
        # Match the whole number after the prefix so that var_1 never rewrites
        # the front of var_10; unknown numbers are left untouched.
        def _swap(found):
            slot = int(found.group(1))
            if slot in mapping:
                return f"$${label}[{mapping[slot]}]"
            return found.group(0)

        return re.sub(re.escape(prefix) + r"(\d+)", _swap, text)

    if start == "var_":
        args = _rewrite_refs(args, "var_", arg_index, "PREV")

    elif start == "temp_":
        args = _rewrite_refs(args, "temp_", temp_index, "PREV")
        args = _rewrite_refs(args, "var_", arg_index, "GLOB_PREV")
        # NOTE(review): unlike make_tool, "loop_var" is not rewritten to
        # "$$LOOP_VAR" here - confirm whether that difference is intended.

    tool = {"tool_name": tool_name, "arguments": []}

    for arg in arg_splitter(args):
        if "=" not in arg:
            continue

        arg_name, arg_value = arg.split("=", 1)
        arg_name = arg_name.strip()
        arg_value = arg_value.strip().replace("\"","").replace("\'","")

        # startswith() instead of arg_value[0]: an empty value ("name=") is
        # kept as an empty string rather than raising IndexError.
        if arg_value.startswith('['):
            tool["arguments"].append({
                "argument_name": arg_name,
                "argument_value": list(arg_value[1:-1].split(",")),
            })
        else:
            tool["arguments"].append({"argument_name": arg_name,"argument_value": arg_value})

    return tool

def process_tool(index,tool_name,args,tools,arg_index,start,temp_index=None):
    """
    Turn one call (``tool_name`` plus raw ``args`` text) into a tool dictionary.

    The argument text is first cut at its closing parenthesis, the tool name
    is resolved with ``update_tool``, and the dictionary is built by
    ``make_tool`` (resolved name) or ``wrong_name_handler`` (unresolved name).
    The dictionary is appended to ``tools`` and its position is recorded under
    ``index`` in ``temp_index`` when ``start == "temp_"``, else in ``arg_index``.

    Returns the tool dictionary that was appended.
    """
    trimmed_args = modify_args(args)
    resolved_name = update_tool(tool_name)

    if resolved_name:
        tool = make_tool(resolved_name,trimmed_args,arg_index,tools,start,temp_index)
    else:
        tool = wrong_name_handler(tool_name,trimmed_args,arg_index,start,temp_index)

    tools.append(tool)

    positions = temp_index if start == "temp_" else arg_index
    positions[index] = len(tools)-1

    return tool

def if_handler(condition,arg_index,tools):
    """
    Turn the text after ``if`` into a ``conditional_magic`` dictionary.

    Steps applied to ``condition``:
      1. surrounding whitespace and one trailing ``:`` are removed;
      2. one pair of parentheses is removed when it encloses the whole condition;
      3. ``var_N`` references are rewritten to ``$$PREV[i]`` using ``arg_index``;
      4. ``range`` is rewritten to ``$$RANGE``;
      5. every ``name(args)`` call (other than RANGE) is processed with
         ``process_tool`` - which appends it to ``tools`` - and replaced in the
         condition by ``$$PREV[<its index>]``.

    Parameters:
        condition (str): raw condition text; may be empty.
        arg_index (dict): ``var_N`` number -> index in ``tools``.
        tools (list): top-level tool list, extended by nested calls.

    Returns:
        dict: ``{"tool_name": "conditional_magic", "condition": str, "true": [], "false": []}``
    """
    def _encloses_everything(text):
        # True only when the opening parenthesis at position 0 is closed by the
        # very last character, so "(a) and (b)" keeps its parentheses.
        depth = 0
        for position, char in enumerate(text):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return position == len(text) - 1
        return False

    def _swap_var(found):
        # Whole-number match: var_1 must not rewrite the front of var_10.
        slot = int(found.group(1))
        if slot in arg_index:
            return f"$$PREV[{arg_index[slot]}]"
        return found.group(0)

    condition = condition.strip()

    # endswith/startswith are safe on an empty condition (a bare "if" line).
    if condition.endswith(':'):
        condition = condition[:-1]

    if condition.startswith('(') and condition.endswith(')') and _encloses_everything(condition):
        condition = condition[1:-1]

    condition = re.sub(r"var_(\d+)", _swap_var, condition)

    condition = condition.replace("range","$$RANGE")

    for function_call in re.findall(r"\w+\s*\([^)]*\)", condition):
        function_call = function_call.strip()
        if function_call.startswith(("RANGE", "$$RANGE")):
            continue
        match = re.match(r"\s*(\w+)\((.*)\)",function_call)
        if match:
            # process_tool() records the nested call under slot 0; keep any real
            # var_0 mapping and put it back once the nested index has been read.
            had_slot_zero = 0 in arg_index
            previous_zero = arg_index.get(0)

            process_tool(0,match.group(1),match.group(2),tools,arg_index,"var_")
            condition = condition.replace(function_call,f"$$PREV[{arg_index[0]}]")

            if had_slot_zero:
                arg_index[0] = previous_zero
            else:
                arg_index.pop(0, None)

    return {
        "tool_name": "conditional_magic",
        "condition": condition,
        "true": [],
        "false": []
    }

def for_handler(looping_var,arg_index,tools):
    """
    Turn the iterable of a ``for loop_var in ...`` line into an
    ``iterational_magic`` dictionary.

    Steps applied to ``looping_var``:
      1. everything from the first ``:`` on is dropped (or, when there is no
         colon, everything from the first ``#`` on) and whitespace is stripped;
      2. ``var_N`` references are rewritten to ``$$PREV[i]`` using ``arg_index``;
      3. ``range`` is rewritten to ``$$RANGE``;
      4. every ``name(args)`` call (other than RANGE) is processed with
         ``process_tool`` - which appends it to ``tools`` - and replaced by
         ``$$PREV[<its index>]``.

    Parameters:
        looping_var (str): raw text after ``for loop_var in``.
        arg_index (dict): ``var_N`` number -> index in ``tools``.
        tools (list): top-level tool list, extended by nested calls.

    Returns:
        dict: ``{"tool_name": "iterational_magic", "looping_var": str, "loop": []}``
    """
    def _swap_var(found):
        # Whole-number match: var_1 must not rewrite the front of var_10.
        slot = int(found.group(1))
        if slot in arg_index:
            return f"$$PREV[{arg_index[slot]}]"
        return found.group(0)

    colon_pos = looping_var.find(":")
    hash_pos = looping_var.find("#")
    if colon_pos != -1:
        looping_var = looping_var[:colon_pos]
    elif hash_pos != -1:
        looping_var = looping_var[:hash_pos]

    looping_var = looping_var.strip()

    looping_var = re.sub(r"var_(\d+)", _swap_var, looping_var)

    looping_var = looping_var.replace("range","$$RANGE")

    for function_call in re.findall(r"\w+\s*\([^)]*\)", looping_var):
        function_call = function_call.strip()
        if function_call.startswith(("RANGE", "$$RANGE")):
            continue
        match = re.match(r"\s*(\w+)\((.*)\)",function_call)
        if match:
            # process_tool() records the nested call under slot 0; keep any real
            # var_0 mapping and put it back once the nested index has been read.
            had_slot_zero = 0 in arg_index
            previous_zero = arg_index.get(0)

            process_tool(0,match.group(1),match.group(2),tools,arg_index,"var_")
            looping_var = looping_var.replace(function_call,f"$$PREV[{arg_index[0]}]")

            if had_slot_zero:
                arg_index[0] = previous_zero
            else:
                arg_index.pop(0, None)

    return {
        "tool_name": "iterational_magic",
        "looping_var": looping_var,
        "loop": []
    }

def arg_splitter(args):
    """
    Split a raw argument string into one piece per argument.

    Commas outside square brackets separate arguments. As a recovery
    heuristic for lists whose closing bracket is missing, whenever the piece
    being built contains a second ``=`` it is cut at the most recently seen
    comma: the part before that comma is emitted with ``]`` appended and the
    remainder becomes the start of the next piece.

    Returns a list of pieces (never empty; the last piece may be an empty string).
    Whitespace is preserved.
    """
    pieces = []
    pending = ""
    depth = 0
    comma_at = -1  # offset in `pending` of the most recent comma seen
    for char in args:
        if char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
        elif char == ',':
            comma_at = len(pending)
            if depth == 0:
                pieces.append(pending)
                pending = ""
                continue
        pending += char
        if pending.count("=") > 1:
            pieces.append(pending[:comma_at] + ']')
            pending = pending[comma_at + 1:]
    pieces.append(pending)
    return pieces

def make_tool(tool_name,args,arg_index,tools,start,temp_index):
    """
    Build the tool dictionary for a call whose tool name has been resolved.

    References to earlier results are rewritten first:
      * ``start == "var_"``:  ``var_N``  -> ``$$PREV[i]``
      * ``start == "temp_"``: ``temp_N`` -> ``$$PREV[i]``, ``var_N`` -> ``$$GLOB_PREV[i]``
        and ``loop_var`` -> ``$$LOOP_VAR``

    Keyword arguments (``name=value``) are then resolved with
    ``update_arg_name`` / ``update_arg_val``; arguments whose name cannot be
    resolved or whose value comes back empty are dropped. If no keyword
    argument survives and the number of pieces equals the number of arguments
    the tool declares, the pieces are treated as positional values in
    declaration order.

    Parameters:
        tool_name (str): resolved tool name (key of ``get_avl_tools()``).
        args (str): argument text of the call.
        arg_index (dict): ``var_N`` number -> index in the top-level tool list.
        tools (list): tool list that nested calls are appended to.
        start (str): ``"var_"`` or ``"temp_"``.
        temp_index (dict | None): ``temp_N`` number -> index inside the current block.

    Returns:
        dict: ``{"tool_name": ..., "arguments": [{"argument_name", "argument_value"}, ...]}``
    """
    def _rewrite_refs(text, prefix, mapping, label):
        # Match the whole number after the prefix so that var_1 never rewrites
        # the front of var_10; unknown numbers are left untouched.
        def _swap(found):
            slot = int(found.group(1))
            if slot in mapping:
                return f"$${label}[{mapping[slot]}]"
            return found.group(0)

        return re.sub(re.escape(prefix) + r"(\d+)", _swap, text)

    if start == "var_":
        args = _rewrite_refs(args, "var_", arg_index, "PREV")

    elif start == "temp_":
        args = _rewrite_refs(args, "temp_", temp_index, "PREV")
        args = _rewrite_refs(args, "var_", arg_index, "GLOB_PREV")
        args = args.replace("loop_var","$$LOOP_VAR")

    tool = {"tool_name": tool_name, "arguments": []}

    split_args = arg_splitter(args)

    # Pass 1: keyword arguments.
    for arg in split_args:
        arg = arg.strip()
        if "=" not in arg:
            continue

        arg_name, arg_value = arg.split("=", 1)
        arg_name = update_arg_name(arg_name.strip(),tool_name)
        if not arg_name:
            continue

        arg_value = arg_value.strip().strip("\"").strip("\'")
        arg_value = update_arg_val(arg_value,arg_name,tool_name,arg_index,tools,start,temp_index)
        if not arg_value:
            continue

        tool["arguments"].append({"argument_name": arg_name, "argument_value": arg_value})

    if tool["arguments"]:
        return tool

    # Pass 2: positional fallback, only when the piece count matches the
    # number of declared arguments exactly.
    declared = get_avl_tools()[tool_name]
    if declared and len(split_args) == len(declared):
        for arg_name, arg in zip(declared, split_args):
            arg_value = arg.strip().strip("\"").strip("\'")

            arg_value = update_arg_val(arg_value,arg_name,tool_name,arg_index,tools,start,temp_index)
            if not arg_value:
                continue

            tool["arguments"].append({"argument_name": arg_name, "argument_value": arg_value})

    return tool

def converter(string):
    """
    The driver function. Processes each line individually and calls functions on the basis of matches.

    Recognized line shapes (leading whitespace is ignored):
      * ``var_N = tool(args)``          - a top-level tool call
      * ``if <condition>``              - opens a conditional block
      * ``else:``                       - switches an open conditional to its false branch
      * ``for loop_var in <iterable>``  - opens a loop block
      * ``temp_N = tool(args)``         - a tool call inside the currently open block
    Lines matching none of these are skipped.

    Returns:
        str: the tool list serialized with ``json.dumps(..., indent=2)`` on success.
        list: an empty list if any exception is raised while parsing.
    """

    try:

        tools = []
        # Maps the N of each var_N to the position of its tool in `tools`.
        arg_index = {}
        # Which kind of block, if any, subsequent temp_N lines belong to.
        inIf = False
        inElse = False
        inFor = False
        for i in string.split("\n"):

            match = re.match(r"\s*var_(\d+)\s*=\s*(\w+)\((.*)\)", i)

            if match:
                # A top-level assignment closes whatever block was open.
                inIf = False
                inElse = False
                inFor = False
                index = int(match.group(1))
                tool_name = match.group(2)
                args = match.group(3)

                # Handles the form `var_N = if(...)`: treated as the start of a conditional.
                if tool_name.strip() == "if":
                    tools.append(if_handler(args,arg_index,tools))
                    ifInd = len(tools)-1
                    inIf = True
                    inFor = False
                    temp_index = {}
                    continue

                process_tool(index,tool_name,args,tools,arg_index,start="var_")
                continue

            match = re.match(r"\s*if\s*(.*)", i)

            if match:
                inIf = True
                inFor = False
                # temp_N numbering starts fresh for every block.
                temp_index = {}

                condition = match.group(1)

                # if_handler may append nested calls to `tools` first, so the
                # conditional's own position is read after the append.
                tools.append(if_handler(condition,arg_index,tools))
                ifInd = len(tools)-1
                continue

            if inIf:

                match = re.match(r"\s*temp_(\d+)\s*=\s*(\w+)\((.*)\)", i)

                if match:
                    index = int(match.group(1))
                    tool_name = match.group(2)
                    args = match.group(3)

                    # Calls inside the if-body are collected in the "true" branch.
                    process_tool(index,tool_name,args,tools[ifInd]["true"],arg_index,"temp_",temp_index)
                    continue
                match = re.match(r"\s*else:\s*",i)

                if match:
                    inElse = True
                    inIf = False
                    temp_index = {}
                    continue

            if inElse:

                match = re.match(r"\s*temp_(\d+)\s*=\s*(\w+)\((.*)\)", i)

                if match:
                    index = int(match.group(1))
                    tool_name = match.group(2)
                    args = match.group(3)

                    # Calls after `else:` are collected in the "false" branch.
                    process_tool(index,tool_name,args,tools[ifInd]["false"],arg_index,"temp_",temp_index)
                    continue

            match = re.match(r"\s*for\s*loop_var\s*in\s*(.*)",i)

            if match:
                inIf = False
                looping_var = match.group(1)
                temp_index = {}
                tools.append(for_handler(looping_var,arg_index,tools))
                inFor = True
                forInd = len(tools)-1
                continue

            if inFor:

                match = re.match(r"\s*temp_(\d+)\s*=\s*(\w+)\((.*)\)", i)

                if match:
                    index = int(match.group(1))
                    tool_name = match.group(2)
                    args = match.group(3)

                    # Calls inside the loop body are collected in the "loop" list.
                    process_tool(index,tool_name,args,tools[forInd]["loop"],arg_index,"temp_",temp_index)
                    continue

        return json.dumps(tools, indent=2)

    # Any failure while parsing is swallowed and reported as an empty list.
    # NOTE(review): the success path returns a JSON string while this path returns
    # a list, and the exception itself is discarded - callers cannot tell why parsing failed.
    except Exception as e:
        return []


## **Initialize prompt**

In [None]:
def init_prompt(tools , query):
  """Assemble the instruction prompt for one query.

  The prompt consists of the ``<s>`` / ``[INST]`` header, the tool registry
  rendered by ``json_to_python``, the query, and the closing ``[/INST]`` tag.

  Parameters:
    tools (dict): tool registry to render.
    query (str): the user's request.

  Returns:
    str: the complete prompt text.
  """
  header = '''
<s>
[INST]
Allowed Tools: '''
  footer = ' [/INST]\n'
  return header + json_to_python(tools) + " Query: " + query + footer

In [None]:
# Preview the prompt produced for a sample query.
print(init_prompt(tools_list , "Prioritize my p0 issues"))

## **Final Run**

In [None]:
def run(query):
  """Answer one query end to end and time it.

  Builds the prompt from the global ``tools_list``, generates the model
  output with ``get_output`` and parses it with ``converter``. The measured
  time covers generation and parsing, not prompt construction.

  Returns:
    tuple: ``(raw model output, converter result, elapsed seconds)``.
  """
  prompt = init_prompt(tools_list , query)
  started_at = time.time()
  raw_output = get_output(prompt)
  parsed_output = converter(raw_output)
  elapsed = time.time() - started_at

  return raw_output , parsed_output , elapsed

## **Play**

In [None]:
# Interactive driver: optionally add tools, optionally delete tools, then run one query.
# Only the exact answer 'Yes' enables a step; anything else skips it.
print("Do You Want to add Tools: Type Yes Or No!")
choice = input()
if choice == 'Yes' :
  print("Add/Modify your tools:")
  tools_list = add_multiple_tools(tools_list)
print("Do You Want to delete Tools: Type Yes Or No!")
choice = input()
if choice == 'Yes' :
  keys = delete_tool()

  for key in keys:
    # pop() with a default: a mistyped or already-removed tool name is reported
    # and skipped instead of aborting the cell with KeyError.
    if tools_list.pop(key, None) is None:
      print("Tool not found, skipping:", key)

print("Give Query:")
query = input()
print("Query: " , query)
python_output , json_output , latency = run(query)
#print("Python_output :\n" , python_output)
print("Output: \n" ,json_output)
print("latency: " , latency)