# Inference

In [None]:
#| default_exp inference

In [None]:
#| export
from __future__ import annotations
import math, random, torch, matplotlib.pyplot as plt, numpy as np, matplotlib as mpl, shutil, os, gzip, pickle, re, copy, time
from pathlib import Path
from functools import partial
import fastcore.all as fc
from glob import glob
import json

from torch import tensor, nn, optim
import torch.nn.functional as F
from datasets import load_dataset
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, default_collate
from torch.nn import init
from torch.nn.utils.rnn import pad_sequence
from typing import List, Optional

from datetime import datetime, timedelta
import calendar
from fastprogress import progress_bar
from einops import rearrange

from toolken.model import *
from toolken.tokenizer import *
from toolken.datasets import *

Inference works the same way as in any other language model, except for 'tool mode'. When the next predicted token is a 'toolken' (a token that represents a function call), some additional processing is needed to formulate the arguments, carry out the function call, and feed the result back into the generation. "Tool mode" is itself just a nested inference task — it prompts the model again to formulate the arguments.

There are two prompt templates — one for the top-level task of answering the user input, and one for the nested task of formulating the arguments for the function call.

In [None]:
path = '../data/gsm8k-xl/template'
# Read through pathlib so each file is closed again, and decode as UTF-8 explicitly rather than
# relying on the platform's default encoding.
prompt_template = Path(path, 'llama_general.txt').read_text(encoding='utf-8')  # top-level answering prompt
func_template = Path(path, 'llama_func.txt').read_text(encoding='utf-8')  # argument-formulation prompt for tool mode

In [None]:
# Display the few-shot template used for the top-level answering task.
print(prompt_template)

Answer the following questions step by step

Question: Mark has 3 tanks for pregnant fish.  Each tank has 4 pregnant fish and each fish gives birth to 20 young.  How many young fish does he have at the end?
Answer: He has 4*3=12 pregnant fish They give birth to 12*20=240 fish #### 240

Question: The math questions in a contest are divided into three rounds: easy, average, and hard. There are corresponding points given for each round. That is 2, 3, and 5 points for every correct answer in the easy, average, and hard rounds, respectively. Suppose Kim got 6 correct answers in the easy; 2 correct answers in the average; and 4 correct answers in the difficult round, what are her total points in the contest?
Answer: Kim got 6 points/round x 2 round = 12 points in the easy round. She got 2 points/round x 3 rounds = 6 points in the average round. She got 4 points/round x 5 rounds = 20 points in the difficult round. So her total points is 12 points + 6 points + 20 points = 38 points. #### 38

Q

In [None]:
# Display the few-shot template used in tool mode to formulate function-call arguments.
print(func_template)

Answer the following questions with <add>, <subtract>, <multiply>, <divide> operators

Question: Mark has 3 tanks for pregnant fish.  Each tank has 4 pregnant fish and each fish gives birth to 20 young.  How many young fish does he have at the end?
Answer: He has 4*3=<multiply>(4, 3)=12 pregnant fish They give birth to 12*20=<multiply>(12, 20)=240 fish #### 240

Question: The math questions in a contest are divided into three rounds: easy, average, and hard. There are corresponding points given for each round. That is 2, 3, and 5 points for every correct answer in the easy, average, and hard rounds, respectively. Suppose Kim got 6 correct answers in the easy; 2 correct answers in the average; and 4 correct answers in the difficult round, what are her total points in the contest?
Answer: Kim got 6 points/round x 2 round = <multiply>(6, 2)=12 points in the easy round. She got 2 points/round x 3 rounds = <multiply>(2, 3)=6 points in the average round. She got 4 points/round x 5 rounds = <

The authors of the original Toolken paper helpfully provided a set of test cases.

In [None]:
path = '../data/gsm8k-xl/test.json'
# One JSON object per line. Close the file when done, decode as UTF-8 (questions contain
# non-ASCII characters such as "’"), and skip blank lines such as a trailing newline.
with open(path, encoding='utf-8') as f:
    test_data = [json.loads(line) for line in f if line.strip()]
raw_test_cases = [i["question"] for i in test_data]
enhanced_v = [i["enhanced_v"] for i in test_data]

# Each question contains placeholders {v_1}, {v_2}, ...; fill them from the matching
# `enhanced_v` list (placeholders are 1-based, list positions 0-based).
test_cases = []
for v, q in zip(enhanced_v, raw_test_cases):
    for i, val in enumerate(v):
        q = q.replace(f"{{v_{i+1}}}", str(val))
    test_cases.append(q)

In [None]:
# First test question, with its {v_i} placeholders already substituted.
test_cases[0]

"Janet’s ducks lay 4096 eggs per day. She eats 27 for breakfast every morning and bakes muffins for her friends every day with 64. She sells the remainder at the farmers' market daily for $8 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

In [None]:
path = '../data/gsm8k-xl/func_dict.json'
# Maps each toolken string (e.g. '<add>') to its index among the toolkens.
# Use a context manager so the file is closed, and decode as UTF-8 explicitly.
with open(path, encoding='utf-8') as f:
    func_dict = json.load(f)
func_dict

{'<add>': 0, '<subtract>': 1, '<multiply>': 2, '<divide>': 3}

In [None]:
#| export
def add(x, y):
    """Return ``x + y``; the tool behind the ``<add>`` toolken."""
    return x + y


def subtract(x, y):
    """Return ``x - y``; the tool behind the ``<subtract>`` toolken."""
    return x - y


def multiply(x, y):
    """Return ``x * y``; the tool behind the ``<multiply>`` toolken."""
    return x * y


def divide(x, y):
    """Return ``x / y`` (true division); the tool behind the ``<divide>`` toolken."""
    return x / y

In [None]:
#| export
def generate(model, prompt, temperature=0.8, top_p=0.95, max_len=512, stop_token=None, vocab_size=32000):
    """Sample tokens after `prompt` until a toolken, a stop token, or the length limit.

    Args:
        model: called as ``model(tokens, start_pos)`` and expected to return next-token logits
            whose ids at or above `vocab_size` are toolkens. It must expose
            ``model.params.max_batch_size`` and ``model.params.max_seq_len``.
        prompt: a prompt string, or a list of prompt strings (batched).
        temperature: softmax temperature; ``0`` (or less) means greedy decoding.
        top_p: nucleus-sampling threshold passed to ``sample_top_p``.
        max_len: maximum number of tokens to generate after the longest prompt
            (the buffer is additionally capped at ``model.params.max_seq_len``).
        stop_token: token ids that end generation once sampled.
        vocab_size: first token id that denotes a toolken.

    Returns:
        ``(tokens, current_idx)``: the ``(bsz, total_len)`` token buffer with the prompts at the
        start of each row, and the index of the last position written.

    Note: stop conditions are checked on the first sequence only.
    Relies on the globals ``tokenizer``, ``device`` and ``sample_top_p``.
    """
    stop_token = [] if stop_token is None else list(stop_token)
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    bsz = len(prompts)
    params = model.params
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    min_prompt_size = min(len(t) for t in prompt_tokens)
    max_prompt_size = max(len(t) for t in prompt_tokens)
    total_len = min(params.max_seq_len, max_len + max_prompt_size)

    # Right-pad every prompt into one buffer; generated tokens are written in place after them.
    tokens = torch.full((bsz, total_len), tokenizer.pad_id).to(device).long()
    for k, t in enumerate(prompt_tokens):
        tokens[k, :len(t)] = torch.tensor(t).long()
    input_text_mask = tokens != tokenizer.pad_id

    start_idx = min_prompt_size
    prev_idx = 0
    current_idx = start_idx - 1  # returned as-is when there is no room left to generate
    for current_idx in range(start_idx, total_len):
        # Only the tokens not yet seen are fed; prev_idx is passed as their start position.
        # Assumes the model returns logits for the last position only — TODO confirm.
        concat_logits = model(tokens[:, prev_idx:current_idx], prev_idx)
        if temperature > 0:
            probs = torch.softmax(concat_logits / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(concat_logits, dim=-1)
        next_token = next_token.reshape(-1)
        # Rows whose (longer) prompt still covers this position keep their prompt token.
        next_token = torch.where(input_text_mask[:, current_idx], tokens[:, current_idx], next_token)
        tokens[:, current_idx] = next_token
        prev_idx = current_idx

        if input_text_mask[0, current_idx]:
            continue  # still inside the first prompt: nothing has been generated for it yet
        first = int(next_token[0])
        if first >= vocab_size or first in stop_token:
            break
    return tokens, current_idx

In [None]:
def decode_toolken(generated_tokens, current_idx, tokenizer, func_dict, vocab_size=32000):
    """Decode token ids up to and including `current_idx`, rendering toolkens as ``"<func>("``.

    Ids below `vocab_size` are ordinary tokens. Each consecutive run of them is decoded in one
    ``tokenizer.decode`` call. Ids at or above `vocab_size` are toolkens and become
    ``f"{name}("`` via `func_dict`. ``"".join`` of the result is the decoded text, and the last
    piece equals ``func + "("`` exactly when the sequence ends on a toolken.

    Args:
        generated_tokens: 1-D sequence of ids, or a 2-D ``(bsz, len)`` tensor as returned by
            `generate` (only the first row is decoded).
        current_idx: index of the last position to decode (inclusive), as returned by `generate`.
        tokenizer: object whose ``decode`` accepts a list of ids.
        func_dict: maps toolken strings such as ``"<add>"`` to their offset from `vocab_size`.
        vocab_size: first id that denotes a toolken.

    Returns:
        list of decoded text pieces.
    """
    seq = generated_tokens[0] if getattr(generated_tokens, "ndim", 1) > 1 else generated_tokens
    id_to_func = {v: k for k, v in func_dict.items()}
    decoded, run = [], []
    # `+ 1`: generate writes the stopping token (e.g. the toolken) AT current_idx.
    for t in seq[:current_idx + 1]:
        t = int(t)
        if t < vocab_size:
            run.append(t)
            continue
        # Decode runs together rather than token by token: per-token decoding with
        # SentencePiece-style tokenizers can drop the spaces between words.
        if run:
            decoded.append(tokenizer.decode(run))
            run = []
        decoded.append(f'{id_to_func[t - vocab_size]}(')
    if run:
        decoded.append(tokenizer.decode(run))
    return decoded

In [None]:
def _parse_number(arg):
    """Parse one argument string: ``"50%"`` -> 0.5, ``"12"`` -> 12, ``"2.5"`` -> 2.5."""
    arg = arg.strip()
    if "%" in arg:
        return float(arg.replace("%", "")) / 100
    try:
        return int(arg)
    except ValueError:
        return float(arg)  # decimals such as "2.5"; raises ValueError for non-numbers


def format_args(cur_generation, func):
    """Convert the LLM's raw argument text for `func` into a list of numbers, e.g. ``[x, y]``.

    Takes the text after the last occurrence of `func` (e.g. ``"<multiply>"``) and strips
    ``=``, ``>`` and ``$``. It then parses the comma-separated values inside the first
    ``(...)`` group.
    Percentages become fractions (``"50%"`` -> 0.5), integers stay ``int``, and other numbers
    become ``float``.

    Args:
        cur_generation: generated text ending in a call such as ``"<add>(3, 4)"``.
        func: the toolken string, e.g. ``"<add>"``.

    Returns:
        list of parsed arguments.

    Raises:
        ValueError: if there is no ``(...)`` group or an argument is not a number.
    """
    args = cur_generation.split(func)[-1]
    args = args.replace("=", "").replace(">", "").replace("((", "(").replace("))", ")")
    args = args.replace("$", "")
    if ", " in args:
        # ", " separates arguments; any other comma is a thousands separator ("1,000") and is dropped
        args = args.replace(", ", ";").replace(",", "").replace(";", ", ")
    args = args.replace(" ", "")
    if "(" not in args or ")" not in args:
        raise ValueError("invalid args")
    # Only the first "(...)" group holds arguments; trailing text such as "=12" is ignored.
    inner = args.split("(", 1)[1].split(")", 1)[0]
    return [_parse_number(a) for a in inner.split(",")]

In [None]:
def complete_function_call(model, current_generation, func, func_template,
                           question=None, temperature=0.8, top_p=0.95, max_len=512,
                           func_list=None):
    """Run "tool mode" for toolken `func` and splice the call's result into the generation.

    Steps:
    1. Re-prompt the model with `func_template` (question substituted) plus the generation so
       far, which should end in e.g. ``"<add>("``.
    2. Let the model complete the argument list until a stop token.
    3. Parse the arguments with `format_args` and call the tool whose ``__name__`` matches `func`.
    4. Return the text before `func`, followed by the result.

    Args:
        model: the toolken model; its ``inference_mode`` is set to ``'baseline'`` here.
        current_generation: text generated so far, ending with ``func + "("``.
        func: the toolken string, e.g. ``"<add>"``.
        func_template: few-shot prompt containing a ``[QUESTION]`` placeholder.
        question: question text for the template; when omitted, the global ``q`` is used
            (the original behaviour).
        temperature, top_p, max_len: sampling settings forwarded to `generate`.
        func_list: callables available as tools, matched by ``__name__``; defaults to
            ``add``, ``subtract``, ``multiply`` and ``divide``.

    Returns:
        ``current_generation`` up to `func`, followed by ``str(result)``.

    Raises:
        Whatever `format_args` raises for unparseable arguments.
        ValueError: if no tool is named like `func`.
    """
    if question is None:
        question = q  # legacy fallback: notebook-global question text (NameError if unset)
    if func_list is None:
        func_list = [add, subtract, multiply, divide]

    # complete the arguments
    func_prompt = func_template.replace('[QUESTION]', question) + current_generation
    model.inference_mode = 'baseline'
    # NOTE(review): 29897 / 3892 are presumably the ids of ")" and ")=" in the LLaMA vocab — confirm
    tokens, end_idx = generate(model, func_prompt, temperature=temperature, top_p=top_p,
                               max_len=max_len, stop_token=[29897, 3892])
    # generate returns (token buffer, last written index), with the prompt at the start of row 0,
    # so decode only the newly generated ids. Assumes this encoding matches generate's own.
    n_prompt = len(tokenizer.encode(func_prompt, bos=True, eos=False))
    current_generation += tokenizer.decode(tokens[0, n_prompt:end_idx + 1].tolist())

    # extract args and do function call
    args = format_args(current_generation, func)
    tools = {f.__name__: f for f in func_list}
    func_name = func[1:-1]  # "<add>" -> "add"
    if func_name not in tools:
        raise ValueError(f"no tool named {func_name!r}")
    res = tools[func_name](*args)
    return current_generation.split(func)[0] + str(res)

In [None]:
def sample(model:callable, prompt:str, tokenizer, func_template:str, func_dict:dict,
           temperature=0.8, top_p=0.95, max_len=512, max_calls=None):
    """Generate a completion of `prompt`, running "tool mode" whenever a toolken is produced.

    Each round:
    1. Generate in ``'func_embedding'`` mode until a toolken or another stop condition.
    2. If the output ends in a toolken, complete and execute that call with
       `complete_function_call`.
    3. Continue generating from the text with the result spliced in.

    Args:
        model: the toolken model (its ``inference_mode`` is switched between rounds).
        prompt: the full top-level prompt.
        tokenizer: passed to `decode_toolken`.
        func_template: few-shot prompt used to formulate function-call arguments.
        func_dict: maps toolken strings such as ``"<add>"`` to their indices.
        temperature, top_p, max_len: sampling settings forwarded to `generate`.
        max_calls: optional cap on the number of tool calls; ``None`` means no cap.

    Returns:
        the final decoded text.
    """
    n_calls = 0
    while True:
        # generate up until the first toolken
        model.inference_mode = 'func_embedding'
        generated_tokens, current_idx = generate(model, prompt, temperature=temperature,
                                                 top_p=top_p, max_len=max_len)
        decoded = decode_toolken(generated_tokens, current_idx, tokenizer, func_dict)
        current_generation = "".join(decoded)

        # "tool mode" — only when generation stopped on a toolken ("<func>(" is the last piece)
        last = decoded[-1] if decoded else ""
        func = next((f for f in func_dict if last == f + "("), None)
        if func is None or (max_calls is not None and n_calls >= max_calls):
            break
        current_generation = complete_function_call(model, current_generation, func, func_template)
        n_calls += 1
        # Continue from the text with the call's result spliced in. Assumes the decoded text
        # starts with the prompt itself, i.e. the token buffer is decoded from position 0.
        # NOTE(review): complete_function_call therefore also receives the full top-level prompt.
        prompt = current_generation
    return current_generation