Source: https://github.com/lucidrains/toolformer-pytorch

In [1]:
!pip install toolformer-pytorch

Collecting toolformer-pytorch
  Downloading toolformer_pytorch-0.0.30-py3-none-any.whl.metadata (771 bytes)
Collecting beartype (from toolformer-pytorch)
  Downloading beartype-0.20.2-py3-none-any.whl.metadata (33 kB)
Collecting x-clip>=0.14.3 (from toolformer-pytorch)
  Downloading x_clip-0.14.4-py3-none-any.whl.metadata (724 bytes)

In [2]:
import torch
from toolformer_pytorch import Toolformer, PaLM

# simple calendar api call - function that returns a string

def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'
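`Calendar()` reads the system clock, so its output changes from day to day. To check the sentence format independently of the current date, a hypothetical variant `calendar_for` (not part of toolformer-pytorch) builds the same string for a fixed date:

```python
import datetime
from calendar import day_name, month_name

def calendar_for(date: datetime.date) -> str:
    # hypothetical helper: same sentence as Calendar(), but for a given date
    return f'Today is {day_name[date.weekday()]}, {month_name[date.month]} {date.day}, {date.year}.'

print(calendar_for(datetime.date(2023, 1, 6)))  # Today is Friday, January 6, 2023.
```

`day_name` and `month_name` follow the process locale, so the names are English under the default C locale.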

# prompt for teaching it to use the Calendar function from above


In [3]:
prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output:
"""

In [4]:
data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]
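The Toolformer call in the next cell first prompts the model with each `data` string inserted at the `[input]` tag of `prompt`. A minimal sketch of that substitution, assuming a plain string replacement (toolformer-pytorch may tokenize or format the prompt differently); `build_prompts` is a hypothetical name:

```python
def build_prompts(template: str, inputs: list, tag: str = "[input]") -> list:
    # hypothetical helper: one prompt per input, with the tag replaced by the input text
    if tag not in template:
        raise ValueError(f"template has no {tag} tag")
    return [template.replace(tag, text) for text in inputs]

template = "Input: [input]\nOutput:"
for p in build_prompts(template, ["The current day of the week is Wednesday."]):
    print(p)
```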

In [9]:
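# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine
# note: PaLM(...) below is randomly initialized (untrained), so it cannot be expected to follow the prompt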

model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
)

# toolformer

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

# invoking this will
# (1) prompt the model with your inputs (data), inserted into [input] tag
# (2) with the sampled outputs, filter out the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can also be used on its own)
# (5) fine-tune on the filtered results

filtered_stats = toolformer(data)

100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


AssertionError: your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering
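As the message says, the assertion fires when the sampled outputs contain no API calls to work with, which is expected from a randomly initialized `PaLM`. The sketch below mimics steps (2) to (4) in plain Python to show what a usable call looks like. It is not toolformer-pytorch's code: the regex, the `[Calendar() → result]` splice format (taken from the Toolformer paper) and the helper names `find_calls`, `execute_and_splice` and `keep_call` are assumptions, and the filter takes precomputed losses on the tokens that follow the call instead of running a model.

```python
import re
from typing import Callable

# matches argument-free calls such as "[Calendar()]" (assumed call syntax)
CALL_RE = re.compile(r"\[(\w+)\(\)\]")

def find_calls(text: str, tool_id: str) -> list:
    # step (2): keep only the calls that name the expected tool
    return [m for m in CALL_RE.finditer(text) if m.group(1) == tool_id]

def execute_and_splice(text: str, tool_id: str, tool: Callable[[], str]) -> str:
    # step (3): run the tool and write its result into the call, Toolformer-paper style
    def replace(m: re.Match) -> str:
        if m.group(1) != tool_id:
            return m.group(0)  # leave calls to other tools untouched
        return f"[{tool_id}() → {tool()}]"
    return CALL_RE.sub(replace, text)

def keep_call(loss_plain: float, loss_call_no_result: float,
              loss_call_with_result: float, tau: float) -> bool:
    # step (4): keep the call only if seeing its result lowers the loss on the
    # following tokens by at least tau, compared with the better of
    # "no call at all" and "call without a result"
    loss_minus = min(loss_plain, loss_call_no_result)
    return loss_minus - loss_call_with_result >= tau

sample = "Today is the first [Calendar()] Friday of the year."
print(len(find_calls(sample, "Calendar")))  # 1
print(execute_and_splice(sample, "Calendar", lambda: "Today is Friday, January 6, 2023."))
print(keep_call(2.0, 1.8, 0.5, tau=1.0))  # True, since 1.8 - 0.5 >= 1.0
```

If no sampled text yields any match in step (2), there is nothing to execute or filter, which is the situation the assertion reports.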

In [7]:
response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# ideally, the model invokes the Calendar tool and uses the API call's response in its answer

100%|██████████| 247/247 [01:28<00:00,  2.79it/s]
0it [00:00, ?it/s]
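`sample_model_with_api_calls` is meant to let the model use the tool while it generates. The Toolformer paper describes decoding this way: generate normally until the model emits `→` after an API call, pause, run the API, insert its response and the closing bracket, then resume. A toy sketch of that loop, where a scripted list of strings stands in for the model (`generate_with_tools` is a hypothetical name; toolformer-pytorch works on token ids, and a real model would condition on the inserted result, which the script cannot do):

```python
import re
from typing import Iterable

# a call that is waiting for its result, e.g. "[Calendar() →" at the end of the text
PENDING_CALL_RE = re.compile(r"\[(\w+)\(\) →$")

def generate_with_tools(tokens: Iterable[str], tools: dict, max_tokens: int = 100) -> str:
    # `tokens` stands in for the model; a real loop would sample each token from logits
    text = ""
    for _, token in zip(range(max_tokens), tokens):
        text += token
        pending = PENDING_CALL_RE.search(text)
        if pending and pending.group(1) in tools:
            # pause decoding, run the tool, insert its response and the closing "]"
            text += " " + tools[pending.group(1)]() + "]"
    return text

script = ["Today is the first ", "[Calendar() →", " Friday of the year."]
print(generate_with_tools(script, {"Calendar": lambda: "Today is Friday, January 6, 2023."}))
```

Calls to tools that are not registered are left pending and decoding simply continues.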


In [10]:
response

'swin gonzalez thanked limerick bloody lighter ultifastenreigning michimichimichimichimichimichimichimichimichimich...