**ToolFormer (Pytorch)_xrsrke** (*https://github.com/xrsrke*)

**Task**: Data Generation - inference (C* from C)
*   Augment data with batch size 1
*   Inference with batch size 1 (CPU only)

**API**
*   Support add custom API
*   Calculator API
*   WolframeAlpha API






$x* = x_{1:i-1}, e(c_i, r_i), x_{i:n}$; $e(c,r) = <API> a_c(i_c) \rightarrow r </API>$ $\Rightarrow x* = x_{1:i-1}, <API> a_c(i_c) \rightarrow r </API>, x_{i:n}$

---
$“ [”, “]”, “->” \text{ instead of } “<API>”, “</API>”, “→” \Rightarrow x* = x_{1:i-1}, [a_c(i_c) \rightarrow r], x_{i:n} \Rightarrow \text{ Inference }\Rightarrow \text{ Decode }\Rightarrow x* = x_{1:i-1}, [a_c(i_c)], x_{i:n}$

---
Input: Joe Biden was born in Scranton, Pennsylvania. $\Rightarrow$ Output: Joe Biden was born in [QA("Where was Joe Biden born?")] Scranton, [QA("In which state is Scranton?")] Pennsylvania

In [None]:
#Cloning a repository and setup
!git clone https://github.com/xrsrke/toolformer.git ./xrsrke
%cd xrsrke
!pip -q install -e .

Cloning into './xrsrke'...
remote: Enumerating objects: 547, done.[K
remote: Counting objects: 100% (122/122), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 547 (delta 115), reused 112 (delta 112), pack-reused 425[K
Receiving objects: 100% (547/547), 1.37 MiB | 4.10 MiB/s, done.
Resolving deltas: 100% (352/352), done.


In [None]:
from toolformer.data_generator import DataGenerator
from toolformer.api import CalculatorAPI, WolframeAPI
from toolformer.prompt import calculator_prompt, wolframe_prompt
from toolformer.utils import yaml2dict

In [None]:
config = yaml2dict('./configs/default.yaml')

In [None]:
#Bloom-560M
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

config.json:   0%|          | 0.00/693 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [None]:
calculator_api = CalculatorAPI("Calculator",
                               calculator_prompt,
                               sampling_threshold=0.2,
                               filtering_threshold=0.2)
apis = [calculator_api]
generator = DataGenerator(config, model, tokenizer, apis=apis)

In [None]:
text = "From this, we have 10 - 5 minutes = 5 minutes."
augumented_text_ids = generator.generate(text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.




#**Clarify Code**

##Cloning a repository, setup and get config

In [1]:
#Cloning a repository, setup and get config
!git clone https://github.com/xrsrke/toolformer.git ./xrsrke
%cd xrsrke
!pip -q install -e .

Cloning into './xrsrke'...
remote: Enumerating objects: 547, done.[K
remote: Counting objects: 100% (122/122), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 547 (delta 115), reused 112 (delta 112), pack-reused 425[K
Receiving objects: 100% (547/547), 1.37 MiB | 1.91 MiB/s, done.
Resolving deltas: 100% (352/352), done.
/content/xrsrke
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m834.2 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.1/803.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12

In [2]:
import re
from typing import Optional
from typing import List, Callable, Tuple, Union, TypedDict

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from torchtyping import TensorType

##Define hyperparameters

In [3]:
import yaml
def yaml2dict(file_path):
    with open(file_path, "r") as file:
        data = yaml.safe_load(file)
    return data

#load config
config = yaml2dict('/content/xrsrke/configs/default.yaml')
config

{'model': {'path': 'bigscience/bloom-560m', 'eos_token_id': 50256},
 'tokenizer': {'path': 'bigscience/bloom-560m'},
 'dataset': {'path': 'the_pile'},
 'data_generator': {'api_start_character': '[',
  'api_end_character': ']',
  'api_output_character': '→',
  'top_k_sampling': 3,
  'sampling_threshold': 0.1,
  'filtering_threshold': 0.05,
  'max_new_tokens': 100,
  'n_api_pad': 100},
 'inference': {'top_k': 5}}

In [4]:
#Hyperparameters
top_k_sampling = config["data_generator"]["top_k_sampling"]
print('top_k_sampling: ', top_k_sampling)
sampling_threshold = config["data_generator"]["sampling_threshold"]
print('sampling_threshold: ', sampling_threshold)
filtering_threshold = config["data_generator"]["filtering_threshold"]
print('filtering_threshold: ', filtering_threshold)

top_k_sampling:  3
sampling_threshold:  0.1
filtering_threshold:  0.05


##Define APIs

In [5]:
from abc import abstractclassmethod
from langchain import PromptTemplate

class BaseAPI:
    def __init__(
        self,
        name: str, # the name of the API call
        prompt_template: PromptTemplate,
        sampling_threshold: float = 0.2,
        filtering_threshold: float = 0.2,
    ):
        self.name = name
        self.prompt_template = prompt_template
        self.sampling_threshold = sampling_threshold
        self.filtering_threshold = filtering_threshold

    @abstractclassmethod
    def execute(self):
        pass

    def __call__(self, *args: str, **kargs: str) -> str:
        output = self.execute(*args, **kargs)
        return str(output)

class CalculatorAPI(BaseAPI):
    def execute(self, input: str) -> str:
        try:
          #evaluate whatever the user enters as a Python expression and return the result
          #eval(2 + 3) return 5, eval("Hello" + " World") return "Hello World"
          return eval(input)
        except:
          return ""

#API cals (calculator) prompt
calculator_prompt = """
Your task is to add calls to a Calculator API to a piece of text. The API call should help you get information required to complete the text. \n
You can call the API by writing "Calculator(operation)!" where "operation" is the type of calculation you want to perform. Here are some examples of API calls:

Input: John has 5 apples and his friend gave him 3 more. John now has 8 apples.
Ouput: John has 5 apples and his friend gave him 3 more. John now has [Calculator("5 + 3")] 8 apples.

Input: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get 4 pieces of candy.
Output: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get [Calculator(24 / 6)] 4 pieces of candy.

Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.

Input: {input}
Output:
"""

# More prompts
# Input: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is 24 square cm.
# Output: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is [Calculator("6 * 4")] square cm.
# Input: The car traveled 200 miles in 4 hours. Its average speed was 50 miles per hour.
# Output: The car traveled 200 miles in 4 hours. Its average speed was [Calculator(200 / 4)] 50 miles per hour.

In [6]:
#Define calulator API
calculator_api = CalculatorAPI("Calculator",
                               calculator_prompt,
                               sampling_threshold=sampling_threshold,
                               filtering_threshold=filtering_threshold)

apis = [calculator_api] #list available APIs

##Load pretrained model (GPU)

In [7]:
#Bloom-560M
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

config.json:   0%|          | 0.00/693 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device: ', device)

#model to the device
model = model.to(device)

device:  cuda


In [9]:
model

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (

##Sampling API Calls

In [11]:
##Load prompt for APIs calling for a given text
api = apis[0]

#Given text
text = "From this, we have 10 - 5 minutes = 5 minutes."
print('Given text: ', text)
print('---')
#Create API call prompt (calculator)
prompt = api.prompt_template.format(input=text)
print('API prompt: ', prompt)
print('---')
#Tokenize API call prompt
prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(device)
print('API prompt_ids: ', prompt_ids)
print('---')
print('API prompt_ids length: ', len(prompt_ids))

Given text:  From this, we have 10 - 5 minutes = 5 minutes.
---
API prompt:  
Your task is to add calls to a Calculator API to a piece of text. The API call should help you get information required to complete the text. 

You can call the API by writing "Calculator(operation)!" where "operation" is the type of calculation you want to perform. Here are some examples of API calls:

Input: John has 5 apples and his friend gave him 3 more. John now has 8 apples.
Ouput: John has 5 apples and his friend gave him 3 more. John now has [Calculator("5 + 3")] 8 apples.

Input: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get 4 pieces of candy.
Output: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get [Calculator(24 / 6)] 4 pieces of candy.

Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.

Input: From this, we have 10 - 5 minutes = 5 minutes.
Output:

-

In [14]:
#Special tokens
start_character = config["data_generator"]["api_start_character"]
print('start_character: ', start_character)
end_character = config["data_generator"]["api_end_character"]
print('end_character: ', end_character)
output_character = config["data_generator"]["api_output_character"]
print('output_character: ', output_character)

#Add a space, because when the model generate a token, it's also include a "space"
api_start_token_id = tokenizer(f' {start_character}', return_tensors="pt")["input_ids"][0].to(device)
print('tokenized_api_start_token_id: ', api_start_token_id)
api_end_token_id = tokenizer(end_character, return_tensors="pt")["input_ids"][0].to(device)
print('tokenized_api_end_token_id: ', api_end_token_id)
api_output_token_id = tokenizer(f'{output_character}', return_tensors="pt")["input_ids"][0].to(device)
print('tokenized_api_output_token_id: ', api_output_token_id)

pad_token_id = tokenizer.pad_token_id
print('tokenized_pad_token_id: ', pad_token_id)
eos_token_id = tokenizer(".\n\n")["input_ids"][0]
print('tokenized_eos_token_id: ', eos_token_id)

start_character:  [
end_character:  ]
output_character:  →
tokenized_api_start_token_id:  tensor([1111], device='cuda:0')
tokenized_api_end_token_id:  tensor([64], device='cuda:0')
tokenized_api_output_token_id:  tensor([18262], device='cuda:0')
tokenized_pad_token_id:  3
tokenized_eos_token_id:  6149


In [16]:
#top k positions have the maximum prob value that the model will call API <=> k candidate positions
def sample_api_position(prompt_ids): #input: API's prompt => output: k candidate positions & its output corresponding (x*)
  prompt_and_generated_ids = prompt_ids
  generated_ids = torch.tensor([]).to(device)
  i = torch.tensor([0]).to(device) #value 0 for the beginning of iterator over all positions in a given text
  api_pos_probs = torch.tensor([])

  with torch.no_grad():
    while True:
      #.unsqueeze(0) => add batch size (only batch size = 1)
      logits = model(input_ids=prompt_and_generated_ids.unsqueeze(0)).logits
      last_logit = logits[0, -1, :]
      probs = torch.softmax(last_logit, dim=-1)
      #logits.shape [batch size, prompt_ids length, embedding dim]
      #=> y_hat: last_logit.shape [embedding dim]
      #=> probability: probs.shape [embedding dim]

      #keep a position that prob at start token id (api_start_prob) > sampling_threshold
      api_start_prob = probs[api_start_token_id]
      if api_start_prob >= sampling_threshold:
          #append (probability values of a position i, position i)
          api_pos_probs = torch.cat([api_pos_probs, torch.tensor([api_start_prob, i]).unsqueeze(0)], dim=0)

      #generate outputs (x*)
      #next_token = torch.multinomial(probs, num_samples=1)
      next_token = torch.argmax(probs, dim=-1) #index of token with the maximum prob value for position i
      next_token = next_token.unsqueeze(0)
      prompt_and_generated_ids = torch.cat([prompt_and_generated_ids, next_token], dim=0) #append next_token into prompt_and_generated_ids
      generated_ids = torch.cat([generated_ids, next_token], dim=0) #append next_token into generated_ids (output)

      #stop condition
      if next_token == eos_token_id:
          break
      else:
          i += 1

  #keep k candidate positions
  if api_pos_probs.numel() == 0: #check if empty (api_pos_probs)
    #.numel(): total number of elements in a tensor
    api_positions = torch.tensor([]).to(device)
  else: #keep top k positions with the highest probability values
    _, indices = torch.sort(api_pos_probs[:, 0], descending=True)
    api_positions = api_pos_probs[indices[:top_k_sampling], 1]

  return api_positions.long(), generated_ids.long()

In [17]:
#Sampling API Calls
api_start_idxs, generated_ids = sample_api_position(prompt_ids)
print(api_start_idxs)
print(generated_ids, tokenizer.decode(generated_ids))

tensor([10])
tensor([ 12620,   1119,     15,   1701,   1542,   1581,    647,    973,  17405,
           564,   1111, 120009,   2623,     11,   1416,    647,    973,     12,
            64,    973,  17405,   6149], device='cuda:0') From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.




In [18]:
def obtain_api_response(prompt_ids, positions, generated_ids):

  MAX_PAD = 50
  pre_api_ids = torch.tensor([]).to(device)

  for position in positions:
    #x*_{0:i}
    text_ids = torch.cat([generated_ids[:position],
                          api_start_token_id], dim=0)
    #padding (max: 50), left
    padded_text_ids = F.pad(text_ids, pad=(MAX_PAD - text_ids.shape[-1], 0), value=pad_token_id)
    #append padded x*_{0:i}
    pre_api_ids = torch.cat([pre_api_ids,
                              rearrange(padded_text_ids, "... -> 1 ...")]) #rearrange() => add batch size

  #prompt length
  PROMPT_LENGTH = len(prompt_ids)

  prompt_and_pre_api_ids = torch.tensor([]).to(device)

  for x in pre_api_ids:
    #append padded x*_{0:i} into prompt_ids
    prompt_and_pre_api_ids = torch.cat([prompt_and_pre_api_ids,
                                        torch.cat([prompt_ids, x]).unsqueeze(0)], dim=0)

  with torch.no_grad():
    #using prompt_ids + padded x*_{0:i} to generate the rest (max tokens = 50)
    candidate_ids = model.generate(input_ids=prompt_and_pre_api_ids.long(),
                                  eos_token_id=eos_token_id,
                                  max_new_tokens=50,
                                  )

  candidate_ids = candidate_ids[:, PROMPT_LENGTH:] #remove prompt, keep output (x*) with padding only

  return candidate_ids

In [19]:
# obtaining api responses
candidate_ids = obtain_api_response(prompt_ids, api_start_idxs, generated_ids)
print(candidate_ids)
print(tokenizer.decode(candidate_ids[0]))

tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,  12620,   1119,     15,   1701,   1542,   1581,
            647,    973,  17405,    564,   1111, 120009,   2623,     11,   1416,
            647,    973,     12,     64,    973,  17405,   6149]],
       device='cuda:0')
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.




##Executing and Filtering API Calls


In [28]:
def _generate_conditioning_prompts(api, candidate_ids):
  #input api, output (x*) => outputs: ([a_c(i_c) → r], [a_c(i_c) →])

  conditioning_api_ids = torch.tensor([]).to(device)
  API_NAME = api.name #API name
  MAX_PAD = 100

  for text_ids in candidate_ids:
      #decode output to text
      text = tokenizer.decode(text_ids, skip_special_tokens=True)
      #extract api request content
      api_request_content = extract_api_request_content(text, api_name=API_NAME)
      #api response
      api_response = api(api_request_content)
      #encode api response
      api_response_ids = tokenizer(api_response, return_tensors="pt")["input_ids"][0].to(device)
      #concat special token (→) with api response
      api_response_with_arrow_ids = torch.cat([api_output_token_id, api_response_ids], dim=0)

      #extract api syntax
      api_syntax = extract_api_syntax(text, api_name=API_NAME)
      #encode api syntax
      api_syntax_ids = tokenizer(api_syntax, return_tensors="pt")["input_ids"][0].to(device)
      #api syntax with response [a_c(i_c) → r]
      api_syntax_with_response_ids = torch.cat([api_syntax_ids[:-1], api_response_with_arrow_ids, api_syntax_ids[-1:]])
      #api syntax without response [a_c(i_c) →]
      api_syntax_without_response_ids = torch.cat([api_syntax_ids[:-1], api_output_token_id, api_syntax_ids[-1:]])

      #padding
      padded_api_with_response = rearrange(F.pad(api_syntax_with_response_ids,
                                                 pad=((MAX_PAD - api_syntax_with_response_ids.shape[-1]), 0),
                                                 value=pad_token_id), "... -> 1 ...")
      padded_api_without_response = rearrange(F.pad(api_syntax_without_response_ids,
                                                    pad=((MAX_PAD - api_syntax_without_response_ids.shape[-1]), 0),
                                                    value=pad_token_id), "... -> 1 ...")

      #append (padded_api_without_response, padded_api_with_response) into conditioning_api_ids
      padded_api_call = torch.cat([padded_api_without_response, padded_api_with_response], dim=0)
      padded_api_call = rearrange(padded_api_call, "... -> 1 ...")
      conditioning_api_ids = torch.cat([conditioning_api_ids, padded_api_call], dim=0).long()

  return conditioning_api_ids

def extract_api_request_content(text: str, api_name: str) -> str:
  """Extract the content of an API request (i_c) from a given text."""
  start_tag = f"{api_name}("
  end_tag = ")"
  start_idx = text.find(start_tag) #first position of start_tag
  if start_idx == -1:
      return None
  start_idx += len(start_tag) #end position of start_tag
  end_idx = text.find(end_tag, start_idx) #position of end_tag
  if end_idx == -1:
      return None
  return text[start_idx:end_idx] #return api_request_content

def extract_api_syntax(text: str, api_name: str) -> str:
    """Extract the API Syntax from a given text."""
    pattern = r"\[{}\(.*?\)\]".format(api_name)
    matches = re.findall(pattern, text)
    return matches #[a_c(i_c)]

In [29]:
#Loss function
def _compute_weight(t: int) -> Union[int, float]:
  #Compute the weight in the loss function. The further away from api start idxs, the smaller the weight
  return max(0, 1-0.2*t)

def _normalize_weights(augmented_text_ids):
  #Normalize the weight of each position in a sequence (normalized weight = weight/sum weight)
  """
  {"api_start_positions":{
    candidate position (10):{
      "seq_positions": {
        candidate position (10):{
          "prompt_ids": api_and_text_ids,
          "unnormalized_weight": _compute_weight(t=j-idx),
          "losses": [],
          "target_ids": torch.tensor([next_token_ids, next_token_ids, next_token_ids])
       -> "normalized_weight": normalized_weight}
          }
          }}}}
  """
  for api_start_position in augmented_text_ids["api_start_positions"].values():
      total_weight = sum([seq_position["unnormalized_weight"] for seq_position in api_start_position["seq_positions"].values()])
      for seq_position in api_start_position["seq_positions"].values():
          seq_position["normalized_weight"] = seq_position["unnormalized_weight"] / total_weight #add {'normalized_weight': 0.0}
  return augmented_text_ids

def _calculate_weighted_loss(augmented_text_ids):
  for position in augmented_text_ids["api_start_positions"]:
    seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
    for i in seq_positions:
      losses = seq_positions[i]["losses"]
      weights = seq_positions[i]["normalized_weight"]
      seq_positions[i]["weighted_losses"] = -losses * weights
  return augmented_text_ids

def _calculate_loss(augmented_text_ids):
  data = {}
  for position in augmented_text_ids["api_start_positions"]:
    seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
    losses = [0, 0, 0]
    for i in seq_positions:
      losses[0] += seq_positions[i]["weighted_losses"][0] # loss for [text]
      losses[1] += seq_positions[i]["weighted_losses"][1] # loss for [api->, text]
      losses[2] += seq_positions[i]["weighted_losses"][2] # loss for [api-result, text]
    data[position] = losses
  return data

def _filter_candidate_by_threshold(losses, candidates):
  filtered_augmented_text_ids = torch.tensor([]).to(device)
  for i, position in enumerate(losses):
    negative_loss = min(losses[position][0], losses[position][1]) #min([text], [api->, text])
    positive_loss = losses[position][2] #[api-result, text]
    if negative_loss - positive_loss >= filtering_threshold:
        # filtered_augmented_text_ids.append(candidates[i])
        filtered_augmented_text_ids = torch.cat([filtered_augmented_text_ids,
                                                 candidates[i].unsqueeze(0)
                                                 ], dim=0)
  return filtered_augmented_text_ids.long()

In [30]:
def extract_conditioning_ids_and_target_ids(augmented_text_ids):
  conditioning_text_ids = torch.tensor([]).to(device) #((seq_len - api_start_idxs.item())*3, MAX_PAD=50)
  target_ids = torch.tensor([]).to(device) #(seq_len - api_start_idxs.item())*3
  for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
      for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
          #append seq_position_dict["target_ids"] into target_ids
          target_ids = torch.concat([target_ids, seq_position_dict["target_ids"]], dim=0)
          for prompt_id in seq_position_dict["prompt_ids"]:
              #append seq_position_dict["prompt_ids"] (padded) into conditioning_text_ids
              conditioning_text_ids = torch.cat([conditioning_text_ids,
                                                 F.pad(prompt_id.long(),
                                                       pad=(50-prompt_id.shape[-1], 0),
                                                       value=pad_token_id).unsqueeze(0)
                                                ], dim=0)
  return conditioning_text_ids.long(), target_ids.long()

def extract_target_logprob_from_logits(logits, target_ids):
  #probs ((seq_len - api_start_idxs.item())*3, embedding dim)
  log_probs = F.log_softmax(logits, dim=-1)
  #target probs ((seq_len - api_start_idxs.item())*3,)
  target_log_probs = log_probs[range(target_ids.shape[-1]), target_ids]
  return target_log_probs

In [31]:
def filter_api(api, text_ids, api_start_idxs, candidate_ids):
  #storage of api_calls with and without response
  conditioning_api_ids = _generate_conditioning_prompts(api, candidate_ids)
  #encode space token
  SPACE_TOKEN = tokenizer(". ", return_tensors="pt")["input_ids"][0].to(device)
  API_LENGTH = 100
  augmented_text_ids = {"api_start_positions": {}}

  for idx, api_ids in zip(api_start_idxs, conditioning_api_ids):
    idx = idx.item() #api start idxs
    seq_len = len(text_ids) #sequence length
    augmented_text_ids["api_start_positions"][idx] = {"seq_positions": {}}

    j = idx
    while j <= seq_len - 1: #a loop from api start idx to the end
      if j == 1:
          j += 1
          continue

      #𝑥_1,…,𝑥_(𝑖−1),<API> 𝑟𝑒𝑠𝑝𝑜𝑛𝑠𝑒 𝑓𝑜𝑟 𝐴𝑃𝐼 </𝐴𝑃𝐼> 𝑥_𝑖,...,𝑥_n

      #begin with tokens 𝑥_1,…,𝑥_(𝑖−1)
      conditioning_text_ids = text_ids[:j]
      api_and_text_ids = torch.stack([F.pad(conditioning_text_ids, pad=(API_LENGTH + len(SPACE_TOKEN), 0), value=pad_token_id), # [text_ids]
                                      torch.cat([api_ids[0], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->, text_ids]
                                      torch.cat([api_ids[1], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->result, text_ids]
                                      ], dim=0) #api_and_text_ids.shape = (3, API_LENGTH + len(SPACE_TOKEN) + len(conditioning_text_ids))
      #next with token x_i
      next_token_ids = text_ids[j]
      #storage
      augmented_text_ids["api_start_positions"][idx]["seq_positions"][j] = {"prompt_ids": api_and_text_ids,
                                                                            #t=j-idx: shift compared to api start idxs
                                                                            "unnormalized_weight": _compute_weight(t=j-idx),
                                                                            "losses": [],
                                                                            "target_ids": torch.tensor([next_token_ids, next_token_ids, next_token_ids]).to(device)
                                                                            }
      j += 1

  #normalize weights
  augmented_text_ids = _normalize_weights(augmented_text_ids)
  #extract conditioning ids (prompt_ids) and target ids
  conditioning_text_ids, target_ids = extract_conditioning_ids_and_target_ids(augmented_text_ids)

  #input_ids ((seq_len - api_start_idxs.item())*3, MAX_PAD=50)
  #=> output.logits ((seq_len - api_start_idxs.item())*3, MAX_PAD=50, embedding dim = 250880)
  output = model(input_ids=conditioning_text_ids.long())
  #logits ((seq_len - api_start_idxs.item())*3, embedding dim)
  logits = output.logits[:, -1, :]

  #softmax => target probs ((seq_len - api_start_idxs.item())*3,)
  log_probs = extract_target_logprob_from_logits(logits, target_ids)

  #update losses by each 3 cases
  for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
      for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
          seq_position_dict["losses"] = log_probs[:3].squeeze(0) #first 3 of (seq_len - api_start_idxs)*3
          log_probs = log_probs[3:] #remove first 3

  #calculate weighted loss for [text], [api->, text], [api-result, text]
  augmented_text_ids = _calculate_weighted_loss(augmented_text_ids)
  #calculate loss
  losses = _calculate_loss(augmented_text_ids)
  #filter candidate by threshold (loss)
  filtered_candidate_ids = _filter_candidate_by_threshold(losses, candidate_ids)
  return filtered_candidate_ids

In [47]:
#storage of filtered_apis
filtered_apis = torch.tensor([]).to(device)
#encode a given text
text_ids = tokenizer(text, return_tensors="pt")["input_ids"][0].to(device)
# filtering
filtered_candidate_ids = filter_api(api, text_ids, api_start_idxs, candidate_ids)
filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)

augumented_text_ids = filtered_apis.long()
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

From this, we have 10 - 5 minutes = [Calculator(10 - 5)] 5 minutes.




#Check :))

In [50]:
def generate(apis, text):
  #storage of filtered_apis
  filtered_apis = torch.tensor([]).to(device)
  for api in apis:
      #batch size = 1
      prompt = api.prompt_template.format(input=text)
      prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].to(device)
      # sampling positions
      api_start_idxs, generated_ids = sample_api_position(prompt_ids)
      # obtaining api responses
      candidate_ids = obtain_api_response(prompt_ids, api_start_idxs, generated_ids)
      # filtering
      text_ids = tokenizer(text, return_tensors="pt")["input_ids"][0].to(device)
      filtered_candidate_ids = filter_api(api, text_ids, api_start_idxs, candidate_ids)
      filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)
  return filtered_apis.long()

In [53]:
calculator_api = CalculatorAPI("Calculator",
                               calculator_prompt,
                               sampling_threshold=0.2,
                               filtering_threshold=0.2
                               )
apis = [calculator_api]

In [None]:
"""calculator_prompt

Your task is to add calls to a Calculator API to a piece of text. The API call should help you get information required to complete the text. \n
You can call the API by writing "Calculator(operation)!" where "operation" is the type of calculation you want to perform. Here are some examples of API calls:

Input: John has 5 apples and his friend gave him 3 more. John now has 8 apples.
Ouput: John has 5 apples and his friend gave him 3 more. John now has [Calculator("5 + 3")] 8 apples.

Input: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get 4 pieces of candy.
Output: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get [Calculator(24 / 6)] 4 pieces of candy.

Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.

Input: {input}
Output:
"""

# More prompts
# Input: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is 24 square cm.
# Output: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is [Calculator("6 * 4")] square cm.
# Input: The car traveled 200 miles in 4 hours. Its average speed was 50 miles per hour.
# Output: The car traveled 200 miles in 4 hours. Its average speed was [Calculator(200 / 4)] 50 miles per hour.

In [63]:
text = "John has 3 apples and his friend gave him 5 more. John now has 8 apples"
augumented_text_ids = generate(apis, text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

John has 3 apples and his friend gave him 5 more. John now has [Calculator("5 + 3")] 8 apples.




In [70]:
text = "John has 7 apples and his friend gave him 5 more. John now has 12 apples"
augumented_text_ids = generate(apis, text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

John has 7 apples and his friend gave him 5 more. John now has [Calculator(7 / 5)] 12 apples.




In [64]:
text = "A rectangle has a length of 2 cm and a width of 3 cm. The area of the rectangle is 6 square cm."
augumented_text_ids = generate(apis, text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

A rectangle has a length of 2 cm and a width of 3 cm. The area of the rectangle is [Calculator(2 * 3)] 6 square cm.




In [69]:
text = "A rectangle has a length of 7 cm and a width of 5 cm. The area of the rectangle is 35 square cm."
augumented_text_ids = generate(apis, text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))

A rectangle has a length of 7 cm and a width of 5 cm. The area of the rectangle is [Calculator(35 / 5)] 35 square cm.




#Final code (CPU)

In [None]:
#Cloning a repository and get config
!git clone https://github.com/xrsrke/toolformer.git ./xrsrke
%cd xrsrke
!pip -q install -e .

In [None]:
import re
from typing import Optional
from typing import List, Callable, Tuple, Union, TypedDict

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from torchtyping import TensorType

In [None]:
#Bloom-560M
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

In [None]:
import yaml
def yaml2dict(file_path):
    with open(file_path, "r") as file:
        data = yaml.safe_load(file)
    return data

config = yaml2dict('/content/xrsrke/configs/default.yaml')
config

In [None]:
from abc import abstractclassmethod
from langchain import PromptTemplate

class BaseAPI:
    def __init__(
        self,
        name: str, # the name of the API call
        prompt_template: PromptTemplate,
        sampling_threshold: float = 0.2,
        filtering_threshold: float = 0.2,
    ):
        self.name = name
        self.prompt_template = prompt_template
        self.sampling_threshold = sampling_threshold
        self.filtering_threshold = filtering_threshold

    @abstractclassmethod
    def execute(self):
        pass

    def __call__(self, *args: str, **kargs: str) -> str:
        output = self.execute(*args, **kargs)
        return str(output)

class CalculatorAPI(BaseAPI):
    def execute(self, input: str) -> str:
        try:
          #evaluate whatever the user enters as a Python expression and return the result
          #eval(2 + 3) return 5, eval("Hello" + " World") return "Hello World"
          return eval(input)
        except:
          return ""

#API cals (calculator) prompt
calculator_prompt = """
Your task is to add calls to a Calculator API to a piece of text. The API call should help you get information required to complete the text. \n
You can call the API by writing "Calculator(operation)!" where "operation" is the type of calculation you want to perform. Here are some examples of API calls:

Input: John has 5 apples and his friend gave him 3 more. John now has 8 apples.
Ouput: John has 5 apples and his friend gave him 3 more. John now has [Calculator("5 + 3")] 8 apples.

Input: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get 4 pieces of candy.
Output: Jane needs to divide 24 pieces of candy equally among 6 kids. Each kid will get [Calculator(24 / 6)] 4 pieces of candy.

Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.

Input: {input}
Output:
"""

# More prompts
# Input: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is 24 square cm.
# Output: A rectangle has a length of 6 cm and a width of 4 cm. The area of the rectangle is [Calculator("6 * 4")] square cm.
# Input: The car traveled 200 miles in 4 hours. Its average speed was 50 miles per hour.
# Output: The car traveled 200 miles in 4 hours. Its average speed was [Calculator(200 / 4)] 50 miles per hour.

In [None]:
#Data Generator
class AugmentedCandidate(TypedDict):
    api_start_positions: int

class DataGenerator(nn.Module):
    def __init__(
        self,
        config: dict,
        model: Callable, tokenizer: Callable,
        apis: List[BaseAPI],
        device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ):
        super().__init__()
        start_character = config["data_generator"]["api_start_character"]
        end_character = config["data_generator"]["api_end_character"]
        output_character = config["data_generator"]["api_output_character"]

        # add a space, because when the model generate a token, it's also include a "space"
        self.api_start_token_id = tokenizer(f' {start_character}', return_tensors="pt")["input_ids"][0]
        self.api_end_token_id = tokenizer(end_character, return_tensors="pt")["input_ids"][0]
        self.api_output_token_id = tokenizer(f'{output_character}', return_tensors="pt")["input_ids"][0]

        self.top_k_sampling = config["data_generator"]["top_k_sampling"]
        self.sampling_threshold = config["data_generator"]["sampling_threshold"]
        self.filtering_threshold = config["data_generator"]["filtering_threshold"]

        self.apis = apis
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device

        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer(".\n\n")["input_ids"][0]

    def sample_api_position(
        self,
        prompt_ids: TensorType["seq_len"], # the ids of the prompt
    ) -> Tuple[
        TensorType["n_positions"], # The positions of api call
        TensorType["seq_len"] # The generated text
    ]:
        """Sampling API positions."""
        # the ids of the prompt and generated_ids
        prompt_and_generated_ids = prompt_ids
        # only the ids of the generated_ids
        generated_ids = torch.tensor([]).to(self.device)
        i = torch.tensor([0]).to(self.device)

        api_pos_probs = torch.tensor([])

        with torch.no_grad():
            while True:
                logits = self.model(
                    input_ids=prompt_and_generated_ids.unsqueeze(0),
                ).logits

                last_logit = logits[0, -1, :]
                probs = torch.softmax(last_logit, dim=-1)
                api_start_prob = probs[self.api_start_token_id]

                if api_start_prob > self.sampling_threshold:
                    api_pos_probs = torch.cat([
                        api_pos_probs,
                        torch.tensor([api_start_prob, i]).unsqueeze(0)
                    ], dim=0)

                # sampling a token
                # next_token = torch.multinomial(probs, num_samples=1)
                next_token = torch.argmax(probs, dim=-1)
                next_token = next_token.unsqueeze(0)

                prompt_and_generated_ids = torch.cat([prompt_and_generated_ids, next_token], dim=0)
                generated_ids = torch.cat([generated_ids, next_token], dim=0)

                if next_token == self.eos_token_id:
                    break
                else:
                    i += 1

        if api_pos_probs.numel() == 0:
            api_positions = torch.tensor([])
        else:
            _, indices = torch.sort(api_pos_probs[:, 0], descending=True)
            top_k_sampling = self.top_k_sampling
            api_positions = api_pos_probs[indices[:top_k_sampling], 1]

        return api_positions.long(), generated_ids.long()

    def obtain_api_response(
        self,
        prompt_ids: TensorType["seq_len"],
        positions: TensorType["n_positions"],
        generated_ids: TensorType["seq_len"]
    ) -> TensorType["n_positions", "seq_len"]:

        MAX_PAD = 50

        # the ids before the start of an api call
        pre_api_ids = torch.tensor([])

        for position in positions:
            text_ids = torch.cat([generated_ids[:position], self.api_start_token_id], dim=0)
            padded_text_ids = F.pad(text_ids, pad=(MAX_PAD - text_ids.shape[-1], 0), value=self.pad_token_id)

            pre_api_ids = torch.cat([
                pre_api_ids,
                rearrange(padded_text_ids, "... -> 1 ...")
            ])

        PROMPT_LENGTH = len(prompt_ids)

        prompt_and_pre_api_ids = torch.tensor([])
        for x in pre_api_ids:
            prompt_and_pre_api_ids = torch.cat([
                prompt_and_pre_api_ids,
                torch.cat([prompt_ids, x]).unsqueeze(0)
            ], dim=0)

        with torch.no_grad():
            candidate_ids = self.model.generate(
                input_ids=prompt_and_pre_api_ids.long(),
                eos_token_id=self.eos_token_id,
                max_new_tokens=50,
            )

        # filter out the prompt template
        # only keep the generated ids
        candidate_ids = candidate_ids[:, PROMPT_LENGTH:]

        return candidate_ids

    def _generate_conditioning_prompts(
        self,
        api: BaseAPI,
        candidate_ids: TensorType["n_candidates", "seq_len"],
    ):
        conditioning_api_ids = torch.tensor([])

        API_NAME = api.name
        MAX_PAD = 100

        def extract_api_request_content(text: str, api_name: str) -> str:
            """Extract the content of an API request from a given text."""
            start_tag = f"{api_name}("
            end_tag = ")"
            start_idx = text.find(start_tag)
            if start_idx == -1:
                return None
            start_idx += len(start_tag)
            end_idx = text.find(end_tag, start_idx)
            if end_idx == -1:
                return None
            return text[start_idx:end_idx]

        def extract_api_syntax(text: str, api_name: str) -> str:
            """Extract the API Syntax from a given text."""
            pattern = r"\[{}\(.*?\)\]".format(api_name)
            matches = re.findall(pattern, text)
            return matches

        for text_ids in candidate_ids:
            # the ids of the prediction
            text = self.tokenizer.decode(text_ids, skip_special_tokens=True)

            api_request_content = extract_api_request_content(text, api_name=API_NAME)
            api_response = api(api_request_content)
            api_response_ids = self.tokenizer(api_response, return_tensors="pt")["input_ids"][0]
            # Format: "-> [api_response]"
            api_response_with_arrow_ids = torch.cat([self.api_output_token_id, api_response_ids], dim=0)

            api_syntax = extract_api_syntax(text, api_name=API_NAME)
            api_syntax_ids = self.tokenizer(api_syntax, return_tensors="pt")["input_ids"][0]
            api_syntax_with_response_ids = torch.cat([api_syntax_ids[:-1], api_response_with_arrow_ids, api_syntax_ids[-1:]])
            api_syntax_without_response_ids = torch.cat([api_syntax_ids[:-1], self.api_output_token_id, api_syntax_ids[-1:]])

            padded_api_without_response = rearrange(
                F.pad(api_syntax_without_response_ids, pad=((MAX_PAD - api_syntax_without_response_ids.shape[-1]), 0), value=self.pad_token_id),
                "... -> 1 ..."
            )
            padded_api_with_response = rearrange(
                F.pad(api_syntax_with_response_ids, pad=((MAX_PAD - api_syntax_with_response_ids.shape[-1]), 0), value=self.pad_token_id),
                "... -> 1 ..."
            )

            padded_api_call = torch.cat([
                padded_api_without_response,
                padded_api_with_response
            ], dim=0)
            padded_api_call = rearrange(padded_api_call, "... -> 1 ...")

            conditioning_api_ids = torch.cat([conditioning_api_ids, padded_api_call], dim=0).long()

        return conditioning_api_ids

    def _filter_candidate_by_threshold(
        self,
        losses,
        candidates: TensorType["seq_len"]
    ):
        filtered_augmented_text_ids = torch.tensor([])
        for i, position in enumerate(losses):
            negative_loss = min(losses[position][0], losses[position][1])
            positive_loss = losses[position][2]

            if negative_loss - positive_loss >= self.filtering_threshold:
                # filtered_augmented_text_ids.append(candidates[i])
                filtered_augmented_text_ids = torch.cat([
                    filtered_augmented_text_ids,
                    candidates[i].unsqueeze(0)
                ], dim=0)

        return filtered_augmented_text_ids.long()

    def filter_api(
        self,
        api: BaseAPI,
        text_ids: TensorType["seq_len"],
        api_start_idxs: TensorType["n_positions"],
        candidate_ids: TensorType["n_positions", "seq_len"]
    ):
        conditioning_api_ids = self._generate_conditioning_prompts(api, candidate_ids)

        SPACE_TOKEN = self.tokenizer(". ", return_tensors="pt")["input_ids"][0]
        API_LENGTH = 100
        augmented_text_ids = {"api_start_positions": {}}

        def _compute_weight(t: int) -> Union[int, float]:
            """Compute the weight in the loss function."""
            return max(0, 1-0.2*t)

        for idx, api_ids in zip(api_start_idxs, conditioning_api_ids):
            idx = idx.item()
            seq_len = len(text_ids)
            augmented_text_ids["api_start_positions"][idx] = {
                "seq_positions": {}
            }

            j = idx
            while j <= seq_len - 1:
                # if the model predic
                if j == 1:
                    j += 1
                    continue

                # in the formua, from x_1 to x_j (include x_j)
                # => generate_ids[:j]
                conditioning_text_ids = text_ids[:j]
                api_and_text_ids = torch.stack([
                    F.pad(conditioning_text_ids, pad=(API_LENGTH + len(SPACE_TOKEN), 0), value=self.pad_token_id), # [text_ids]
                    torch.cat([api_ids[0], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->, text_ids]
                    torch.cat([api_ids[1], SPACE_TOKEN, conditioning_text_ids], dim=0), # [api->result, text_ids]
                ], dim=0)

                # the next token after x_j
                next_token_ids = text_ids[j]
                augmented_text_ids["api_start_positions"][idx]["seq_positions"][j] = {
                    "prompt_ids": api_and_text_ids,
                    "unnormalized_weight": _compute_weight(t=j-idx),
                    "losses": [],
                    "target_ids": torch.tensor([next_token_ids, next_token_ids, next_token_ids])
                }
                j += 1

        def _normalize_weights(augmented_text_ids):
            """Normalize the weight of each position in a sequence."""
            for api_start_position in augmented_text_ids["api_start_positions"].values():
                total_weight = sum([seq_position["unnormalized_weight"] for seq_position in api_start_position["seq_positions"].values()])
                for seq_position in api_start_position["seq_positions"].values():
                    seq_position["normalized_weight"] = seq_position["unnormalized_weight"] / total_weight
            return augmented_text_ids

        augmented_text_ids = _normalize_weights(augmented_text_ids)

        def extract_conditioning_ids_and_target_ids(augmented_text_ids):
            conditioning_text_ids = torch.tensor([])
            target_ids = torch.tensor([])
            for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
                for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
                    target_ids = torch.concat([target_ids, seq_position_dict["target_ids"]], dim=0)
                    for prompt_id in seq_position_dict["prompt_ids"]:
                        conditioning_text_ids = torch.cat([
                            conditioning_text_ids,
                            F.pad(prompt_id.long(), pad=(50-prompt_id.shape[-1], 0), value=self.pad_token_id).unsqueeze(0)
                        ], dim=0)
            return conditioning_text_ids.long(), target_ids.long()

        conditioning_text_ids, target_ids = extract_conditioning_ids_and_target_ids(augmented_text_ids)

        output = self.model(input_ids=conditioning_text_ids.long())
        logits = output.logits[:, -1, :]

        def extract_target_logprob_from_logits(logits, target_ids):
            log_probs = F.log_softmax(logits, dim=-1)
            target_log_probs = log_probs[range(target_ids.shape[-1]), target_ids]
            return target_log_probs

        log_probs = extract_target_logprob_from_logits(logits, target_ids)

        for _, api_start_position_dict in augmented_text_ids["api_start_positions"].items():
            for _, seq_position_dict in api_start_position_dict["seq_positions"].items():
                seq_position_dict["losses"] = log_probs[:3].squeeze(0)
                log_probs = log_probs[3:]

        def _calculate_weighted_loss(augmented_text_ids):
            for position in augmented_text_ids["api_start_positions"]:
                seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
                for i in seq_positions:
                    losses = seq_positions[i]["losses"]
                    weights = seq_positions[i]["normalized_weight"]
                    seq_positions[i]["weighted_losses"] = -losses * weights

            return augmented_text_ids

        augmented_text_ids = _calculate_weighted_loss(augmented_text_ids)

        def _calculate_loss(augmented_text_ids):
            data = {}
            for position in augmented_text_ids["api_start_positions"]:
                seq_positions = augmented_text_ids["api_start_positions"][position]["seq_positions"]
                losses = [0, 0, 0]
                for i in seq_positions:
                    losses[0] += seq_positions[i]["weighted_losses"][0] # loss for [text]
                    losses[1] += seq_positions[i]["weighted_losses"][1] # loss for [api->, text]
                    losses[2] += seq_positions[i]["weighted_losses"][2] # loss for [api-result, text]
                data[position] = losses

            return data

        losses = _calculate_loss(augmented_text_ids)
        filtered_candidate_ids = self._filter_candidate_by_threshold(losses, candidate_ids)
        return filtered_candidate_ids

    def generate(
        self,
        text: str,
    ) -> TensorType["n_apis", "n_candidates", "seq_len"]:
        filtered_apis = torch.tensor([])

        for api in self.apis:
            # TODO: add support batch
            prompt = api.prompt_template.format(input=text)
            prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]

            # sampling positions
            api_start_idxs, generated_ids = self.sample_api_position(prompt_ids)

            # obtaining api responses
            candidate_ids = self.obtain_api_response(prompt_ids, api_start_idxs, generated_ids)

            # filtering
            text_ids = self.tokenizer(text, return_tensors="pt")["input_ids"][0]

            # return prompt_ids, api_start_idxs, generated_ids, candidate_ids, text_ids
            filtered_candidate_ids = self.filter_api(api, text_ids, api_start_idxs, candidate_ids)

            filtered_apis = torch.cat([filtered_apis, filtered_candidate_ids.unsqueeze(0)], dim=0)

        return filtered_apis.long()

In [None]:
calculator_api = CalculatorAPI("Calculator",
                               calculator_prompt,
                               sampling_threshold=0.2,
                               filtering_threshold=0.2
                               )

apis = [calculator_api]

generator = DataGenerator(config, model, tokenizer, apis=apis)

text = "From this, we have 10 - 5 minutes = 5 minutes."
augumented_text_ids = generator.generate(text)
print(tokenizer.decode(augumented_text_ids[0][0], skip_special_tokens=True))