# Toolformer - dataset filtering

Main bulk of toolformer functions for building a dataset for finetuning.

In [None]:
#| default_exp filtering

In [None]:
#| export
from __future__ import annotations
import math, random, torch, matplotlib.pyplot as plt, numpy as np, matplotlib as mpl, shutil, os, gzip, pickle, re, copy, time
from pathlib import Path
from functools import partial
import fastcore.all as fc
from glob import glob

from torch import tensor, nn, optim
import torch.nn.functional as F
from datasets import load_dataset
import torchvision.transforms.functional as TF
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, default_collate
from torcheval.metrics import MulticlassAccuracy
from torch.nn import init
from torch.nn.utils.rnn import pad_sequence
from typing import List

from datetime import datetime, timedelta
import calendar
from fastprogress import progress_bar
from einops import rearrange

from toolformer.datasets import *
from toolformer.tokenizer import *
from toolformer.model import *
from toolformer.tools import *

We are going to use in-context learning to build the dataset on which the model is finetuned. We'll start with a prompt that teaches the model how to use a tool, and build a dataset of examples which vary the final input value inside this prompt. This first involves choosing a pair of tokens to represent the beginning and end of an instance of tool usage.

Through trial and error, I chose "<%" and "%>" because these were the shortest strings I could find that were a) encoded as a single token, b) present only once in the vocabulary (i.e. there are no duplicates) and c) unlikely to come up otherwise.

In [None]:
# Location of the LLaMA tokenizer model file on the local machine.
path = '/home/models/foundation/LLaMA/tokenizer.model'
# `Tokenizer` comes from one of the toolformer star-imports above.
tokenizer = Tokenizer(path)

In [None]:
# Candidate delimiter strings: count how many vocabulary ids decode to each one.
# A count of exactly 1 means the string is a single, unambiguous token.
p = ['{$', '{:', '!>', '<!', '<%', '%>']

# Decode the vocabulary once (size taken from the tokenizer rather than hard-coded)
# instead of re-decoding every id for every candidate.
decoded_vocab = [tokenizer.decode(i) for i in range(tokenizer.n_words)]
for a in p:
    print(f'{a} : {decoded_vocab.count(a)}')

{$ : 2
{: : 2
!> : 0
<! : 1
<% : 1
%> : 1


In [None]:
# Token ids of the chosen delimiters, encoded without BOS/EOS markers (second and third arguments).
tokenizer.encode('<%', False, False), tokenizer.encode('%>', False, False)

([20577], [6580])

In [None]:
# Hand-written samples for exercising the filtering helpers below:
# no call, one call, several calls, and two Calendar calls.
test_cases = [
    'Output: The number in the next term is 18 + 12 x 3 = 54.',
    'Output: From this, we have 4 * 30 minutes = <% Calculator(4 * 30) %> 120 minutes. ',
    'Output: From this, <% Calculator(4 * 30) %> we have 4 * 30 minutes = <% Calculator(3 * 50) %> 120 minutes. <% Calculator(9 * 80) %>',
    'Output: Today is the first <% Calendar() %> Friday of the year.',
    'Output: The president of the United States is <% Calendar() %> Joe Biden.'
] 

In [None]:
#| export
def filter_and_retain_only_first_api(prompts:List[str], api_start_char:str, api_end_char:str, start_idxs:Optional[List[int]]=None):
    """
        Keep only the prompts that contain an API call, and strip every call after the first.

        An API call is text of the form `<start_char> api(expression) <end_char>` that is
        preceded by whitespace and does not span a line break. Only the part of each
        prompt from its start index onwards is searched; the text before it is copied
        through untouched. Prompts without any call are discarded.

        Args:
            prompts: strings to search.
            api_start_char: delimiter that opens an API call (matched literally).
            api_end_char: delimiter that closes an API call (matched literally).
            start_idxs: per-prompt character offset at which the search starts;
                defaults to 0 for every prompt.

        Returns:
            `(kept_prompts, kept_start_idxs)`: the retained prompts with only their first
            call left in place, and the start index that belongs to each retained prompt.
    """
    prompts_with_api_calls, indexes = [], []
    if start_idxs is None: start_idxs = [0] * len(prompts)
    # Delimiters are escaped so characters such as `$` or `{` are matched literally.
    pattern = re.compile(rf'(?<=\s){re.escape(api_start_char)}\s.*?\s{re.escape(api_end_char)}')
    for prompt, idx in zip(prompts, start_idxs):
        head, tail = prompt[:idx], prompt[idx:]
        try:
            calls = list(pattern.finditer(tail))
        except TypeError:
            # best effort: report an entry that cannot be searched and move on
            print(tail)
            continue
        if not calls: continue
        # Cut the later calls out by position. Removing them by text would delete the
        # *first* call whenever a later call repeats it verbatim.
        pieces, cursor = [], 0
        for extra in calls[1:]:
            pieces.append(tail[cursor:extra.start()])
            cursor = extra.end()
        pieces.append(tail[cursor:])
        prompts_with_api_calls.append(head + ''.join(pieces))
        indexes.append(idx)
    return prompts_with_api_calls, indexes

In [None]:
# Run the filter over the hand-written samples; the sample without a call is dropped
# and only the first call of the multi-call sample survives.
api_start_char, api_end_char = '<%', '%>'
test_cases, i = filter_and_retain_only_first_api(test_cases, api_start_char, api_end_char)
for a in test_cases: print(a)

Output: From this, we have 4 * 30 minutes = <% Calculator(4 * 30) %> 120 minutes. 
Output: From this, <% Calculator(4 * 30) %> we have 4 * 30 minutes =  120 minutes. 
Output: Today is the first <% Calendar() %> Friday of the year.
Output: The president of the United States is <% Calendar() %> Joe Biden.


In [None]:
#| export
def format_api_calls(results, prompts, api_start_char:str, api_end_char:str, start_idxs:List[int]=None):
    """
        Insert each API result into the first API call of its prompt.

        A call `<start_char> api(expression) <end_char>` becomes
        `<start_char> api(expression) → result <end_char>`. Only the text from the
        prompt's start index onwards is searched, and only the first call found
        there is rewritten.

        Args:
            results: one result per prompt; converted with `str()`.
            prompts: prompts that each contain an API call after their start index.
            api_start_char: delimiter that opens an API call (matched literally).
            api_end_char: delimiter that closes an API call (matched literally).
            start_idxs: per-prompt character offset at which the search starts;
                defaults to 0 for every prompt.

        Returns:
            The prompts with the result spliced into their first call.

        Raises:
            ValueError: if a prompt has no API call after its start index.
    """
    if start_idxs is None: start_idxs = [0] * len(prompts)
    pattern = re.compile(f'{re.escape(api_start_char)}.*?{re.escape(api_end_char)}')
    prompts_with_responses = []
    for result, prompt, idx in zip(results, prompts, start_idxs):
        head, tail = prompt[:idx], prompt[idx:]
        found = pattern.search(tail)
        if found is None:
            raise ValueError(f'no API call found in prompt: {tail!r}')
        call = found.group(0)
        call_with_response = call.replace(api_end_char, '') + '→ ' + str(result) + f' {api_end_char}'
        # Splice by position so an identical call later in the text is left untouched.
        tail = tail[:found.start()] + call_with_response + tail[found.end():]
        prompts_with_responses.append(head + tail)
    return prompts_with_responses

In [None]:
#| export
def make_api_calls(prompts:List[str], api_start_char:str, api_end_char:str, start_idxs:List[int]=None):
    """
        Extracts the first API call in the format <start_char> api(expression) <end_char> from each
        string, executes it and returns new strings that include the response in the format
        <start_char> api(expression) → response <end_char>.

        Supported tools are `Calculator(expression)` and `Calendar()`; the tool name is
        matched case-insensitively. A prompt is skipped (not an error) when it has no
        parsable call after its start index, names an unknown tool, or the tool raises.

        ### Example

        input: 'The number in the next term is 18 + 12 x 3 = <% Calculator(18 + 12 * 3) %> 54.'
        output: 'The number in the next term is 18 + 12 x 3 = <% Calculator(18 + 12 * 3) → 54 %> 54.'

        Args:
            prompts: strings that are expected to contain one API call each.
            api_start_char: delimiter that opens an API call (matched literally).
            api_end_char: delimiter that closes an API call (matched literally).
            start_idxs: per-prompt character offset at which the search starts;
                defaults to 0 for every prompt.

        Returns:
            `(prompts_with_responses, indexes)` where `indexes` are the positions in
            `prompts` of the entries that were executed successfully.
    """
    results, indexes = [], []
    s, e = api_start_char, api_end_char
    if start_idxs is None: start_idxs = [0] * len(prompts)
    call_pattern = re.compile(f'{re.escape(s)}.*?{re.escape(e)}')
    name_pattern = re.compile(rf'{re.escape(s)}\s*(.*?)\(')
    args_pattern = re.compile(r'\((.*?)\)')
    for i, (prompt, idx) in enumerate(zip(prompts, start_idxs)):
        found = call_pattern.search(prompt[idx:])
        if found is None: continue
        call = found.group(0)
        name_match, args_match = name_pattern.search(call), args_pattern.search(call)
        if name_match is None or args_match is None: continue
        tool, expression = name_match.group(1).lower(), args_match.group(1)
        try:
            if tool == "calculator": res = Calculator(expression)
            elif tool == "calendar": res = Calendar()
            # unknown tool: skip it instead of re-using the previous prompt's result
            else: continue
        except Exception:
            # best effort: a failing tool call just drops this prompt
            continue
        results.append(res)
        indexes.append(i)
    if not indexes: return [], []
    kept_prompts = [prompts[i] for i in indexes]
    kept_start_idxs = [start_idxs[i] for i in indexes]
    prompts_with_responses = format_api_calls(results, kept_prompts, api_start_char, api_end_char, start_idxs=kept_start_idxs)
    return prompts_with_responses, indexes

In [None]:
# Execute the call left in each sample and print the text with the response spliced in.
test_cases,i = make_api_calls(test_cases, api_start_char, api_end_char)
for a in test_cases: print(a)

Output: From this, we have 4 * 30 minutes = <% Calculator(4 * 30) → 120 %> 120 minutes. 
Output: From this, <% Calculator(4 * 30) → 120 %> we have 4 * 30 minutes =  120 minutes. 
Output: Today is the first <% Calendar() → Today is Thursday, July 20, 2023. %> Friday of the year.
Output: The president of the United States is <% Calendar() → Today is Thursday, July 20, 2023. %> Joe Biden.


In [None]:
#| export
def get_probs(token_ids, logits):
    """
        Returns, for every position but the last, the probability the model assigned
        to the token that actually follows in the input sequence.

        The softmax is computed in float32 regardless of the logits dtype so that
        small probabilities from a half-precision model do not underflow to zero
        before a later `log()`.

        Args:
            token_ids: integer tensor indexed as (batch, position).
            logits: tensor indexed as (batch, position, vocabulary).

        Returns:
            float32 tensor indexed as (batch, position - 1).
    """
    # position i predicts token i + 1: drop the last prediction and the first token
    next_token_ids = token_ids[:, 1:].unsqueeze(-1)
    probs = torch.softmax(logits[:, :-1], dim=-1, dtype=torch.float32)
    return probs.gather(-1, next_token_ids).squeeze(-1)

In [None]:
#| export
def weight_func(t):
    """Linear decay: weight 1 at t=0, falling by 0.2 per step and floored at 0."""
    decayed = 1. - t * 0.2
    return decayed.clamp(min=0.)

In [None]:
#| export
def get_weights(tokens, search_token_id, pad_id=-1, weight_func=weight_func, start_index=None):
    """
        Builds a per-position weight vector around the first occurrence of
        search_token_id in each row (occurrences before the row's start index
        are ignored).

        Positions are numbered 0, 1, 2, ... starting at the search token and
        passed through weight_func; positions before it are numbered pad_id.
        With the default weight_func and pad_id=-1 this gives 1.2 for everything
        before the search token, 1.0 at the token, then 0.8, 0.6, 0.4, 0.2 and 0
        afterwards. A row that does not contain the token gets 1.2 everywhere.

        Args:
            tokens: integer tensor indexed as (batch, position).
            search_token_id: token id to anchor the weights on.
            pad_id: position number given to everything before the search token.
            weight_func: maps the position numbers to weights.
            start_index: per-row offset from which the search starts;
                defaults to 0 for every row.

        Returns:
            Tensor of weights with the same shape as tokens.
    """
    if start_index is None: start_index = [0] * len(tokens)
    # find the search token, ignoring anything before each row's start index
    is_token_id_mask = torch.zeros_like(tokens, dtype=torch.bool)
    for i, row in enumerate(tokens):
        idx = start_index[i]
        is_token_id_mask[i, idx:] = (row[idx:] == search_token_id)
    # 0 before the first hit, then 1, 2, 3, ... from the hit onwards
    arange = (is_token_id_mask.cumsum(dim=-1) > 0).cumsum(dim=-1)
    before_token_mask = arange == 0
    # shift so the hit itself is 0, and number everything before it pad_id
    arange = (arange - 1).masked_fill(before_token_mask, pad_id)
    weights = weight_func(arange)
    # NOTE(review): this only zeroes weights that are exactly equal to pad_id; the
    # default weight_func is clamped at 0 and never produces -1, so with defaults
    # this is a no-op. Confirm whether masking on `arange == pad_id` was intended.
    return weights.masked_fill(weights == pad_id, 0.)

In [None]:
#| export
def toolformer_probability_filter(tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses, api_start_token, api_end_token, tau_filter=1., start_idxs=None, device='cuda', model=None):
    """
        Keeps the examples for which seeing the API response lowers the weighted
        next-token loss by at least tau_filter compared with the better of
        "no API call" and "API call without response".

        Args:
            tokens_without_api_calls: token ids of the original text, (batch, position) or 1-D.
            tokens_with_api_calls: token ids of the text with the API call inserted.
            tokens_with_api_responses: token ids of the text with call and response inserted.
            api_start_token: token id of the delimiter that opens an API call.
            api_end_token: token id of the delimiter that closes an API call.
            tau_filter: minimum loss reduction required to keep an example.
            start_idxs: per-example token offset from which the delimiters are
                searched; defaults to 0 for every example.
            device: device the tokens are moved to before the forward passes.
            model: callable as `model(tokens, start_pos=0)` that returns logits and
                has an `eval()` method. When omitted, the module-level name `model`
                is used, as before.

        Returns:
            The three token tensors restricted to the examples that passed the filter.
    """
    if model is None:
        # backwards compatible: earlier versions always read the module-level `model`
        try: model = globals()['model']
        except KeyError: raise NameError("name 'model' is not defined; pass model=... explicitly") from None

    def add_dims(x): return x[None, :] if len(x.shape) < 2 else x

    tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses = (
        add_dims(t).to(device) for t in (tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses)
    )
    if start_idxs is None: start_idxs = [0] * len(tokens_with_api_calls)
    start_index = tensor(start_idxs)

    # get the logits
    with torch.no_grad():
        model.eval()
        logits, logits_with_api_calls, logits_with_api_responses = map(
            partial(model, start_pos=0), (tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses)
        )

    # get the predicted probabilities
    probs_without_api_calls = get_probs(tokens_without_api_calls, logits)
    probs_with_api_calls = get_probs(tokens_with_api_calls, logits_with_api_calls)
    probs_with_api_responses = get_probs(tokens_with_api_responses, logits_with_api_responses)

    # get the weightings; the original text has no API token, so its weights are anchored
    # on the start token of the with-call sequence, shifted left by one
    weights_without_api_calls = get_weights(tokens_with_api_calls[:, 1:], api_start_token, start_index=start_index)
    weights_with_api_calls = get_weights(tokens_with_api_calls[:, :-1], api_end_token, start_index=start_index)
    weights_with_api_responses = get_weights(tokens_with_api_responses[:, :-1], api_end_token, start_index=start_index)

    for w in weights_without_api_calls: assert w.sum() > 0

    # calculate the loss for each version
    def loss(weights, probs):
        # Clamp before the log: a probability of exactly 0 gives -inf, and
        # 0 * -inf = nan would poison the whole row and silently reject it.
        log_probs = probs.float().clamp_min(torch.finfo(torch.float32).tiny).log()
        return -(weights * log_probs).sum(-1)
    loss_original = loss(weights_without_api_calls, probs_without_api_calls)
    loss_api = loss(weights_with_api_calls, probs_with_api_calls)
    loss_response = loss(weights_with_api_responses, probs_with_api_responses)

    # toolformer filtering
    l_minus = torch.minimum(loss_original, loss_api)
    l_plus = loss_response
    t_mask = (l_minus - l_plus) >= tau_filter
    return tokens_without_api_calls[t_mask], tokens_with_api_calls[t_mask], tokens_with_api_responses[t_mask]

In [None]:
#| export
def sample(model, tokenizer, prompts: List[str], max_gen_len: int, temperature: float = 0.8, top_p: float = 0.95, decode=False, make_api_calls=False, device='cuda'):
    """
        Autoregressively continues a batch of prompts.

        Args:
            model: object with `params.max_batch_size`, `params.max_seq_len` and a
                `forward(tokens, start_pos)` method returning next-token logits.
            tokenizer: object with `encode(text, bos=..., eos=...)` and `pad_id`.
            prompts: texts to continue.
            max_gen_len: maximum number of tokens generated beyond the longest prompt.
            temperature: softmax temperature; 0 (or less) selects greedy argmax decoding.
            top_p: nucleus-sampling threshold passed to `sample_top_p`.
            decode: when true, return the result of `decode_tokens(...)` instead of token ids.
            make_api_calls: unused; kept for interface compatibility.
            device: device on which the token buffer is allocated.

        Returns:
            Long tensor of shape (batch, total_len) with prompts and generated tokens,
            or the output of `decode_tokens` when `decode` is true.
    """
    bsz = len(prompts)
    params = model.params
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]
    prompt_lens = [len(t) for t in prompt_tokens]
    min_prompt_size, max_prompt_size = min(prompt_lens), max(prompt_lens)

    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
    # Allocate directly with the final dtype and device rather than relying on the
    # (possibly modified) default tensor type followed by a move and a cast.
    tokens = torch.full((bsz, total_len), tokenizer.pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    input_text_mask = tokens != tokenizer.pad_id
    prev_pos = 0

    # generation never needs gradients
    with torch.no_grad():
        # start where the shortest prompt ends; longer prompts are copied through below
        for cur_pos in range(min_prompt_size, total_len):
            logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos
    return tokens if not decode else decode_tokens(tokenizer, tokens, prompt_tokens, max_gen_len)

In [None]:
#| export
@torch.no_grad()
def build_finetune_dataset(dataloader, model, tokenizer, api_start_char='<%', api_end_char='%>', return_tokens=True, device='cuda'):
    """
        Samples API calls using in-context learning, and returns a dataset
        that contains only examples for which calling the API increased the 
        model's ability to predict the next token.

        Args:
            dataloader: yields batches of `(prompts, (texts, start_idxs))`, where
                `start_idxs` is the character offset at which the generated text
                starts inside each sampled prompt.
            model: language model used for sampling; moved to `device`.
            tokenizer: tokenizer matching the model.
            api_start_char: delimiter that opens an API call.
            api_end_char: delimiter that closes an API call.
            return_tokens: return token tensors when true, decoded strings otherwise.
            device: device used for sampling and filtering.

        Returns:
            A list of CPU token tensors (with the API call inserted), or the decoded
            strings when `return_tokens` is false.
    """
    finetune_data = []
    model = model.to(device)
    api_start_token = tokenizer.encode(api_start_char, False, False)[0]
    api_end_token = tokenizer.encode(api_end_char, False, False)[0]
    to_tokens = lambda l: pad_sequence(encode_to_tensor(tokenizer, l), batch_first=True)
    for batch in progress_bar(dataloader, leave=False):

        # assemble the null prompts assuming no API calls
        prompts, (data_without_api_calls, start_idxs) = batch
        data_without_api_calls = [p + d for p, d in zip(prompts, data_without_api_calls)]

        # generate samples with possible API calls
        sampled_prompts = sample(model, tokenizer, prompts, max_gen_len=100, decode=True, device=device)

        # Filter to a single API call per prompt. This is done one prompt at a time so
        # that the position of every survivor inside the batch is known; the later
        # selections must index the unfiltered batch, not the filtered list.
        batch_positions, data_with_api_calls, kept_start_idxs = [], [], []
        for pos, (sampled, start_idx) in enumerate(zip(sampled_prompts, start_idxs)):
            kept, kept_idx = filter_and_retain_only_first_api([sampled], api_start_char, api_end_char, [start_idx])
            if kept:
                batch_positions.append(pos)
                data_with_api_calls.append(kept[0])
                kept_start_idxs.append(kept_idx[0])
        if not data_with_api_calls: continue

        # make the api calls; a failure here skips the batch (best effort)
        try: data_with_api_responses, indexes = make_api_calls(data_with_api_calls, api_start_char, api_end_char, kept_start_idxs)
        except Exception: continue
        if not data_with_api_responses: continue

        # retain only data where we have a) without call, b) with call and c) with response
        # (`indexes` are positions in the filtered list; map them back to the batch)
        data_with_api_calls = [data_with_api_calls[i] for i in indexes]
        batch_positions = [batch_positions[i] for i in indexes]
        data_without_api_calls = [data_without_api_calls[i] for i in batch_positions]

        # convert to tokens and pad all three variants to one common length
        tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses = map(
            to_tokens, (data_without_api_calls, data_with_api_calls, data_with_api_responses)
        )
        tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses = torch.chunk(
            pad_sequence(
                [row for group in (tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses) for row in group],
                batch_first=True
        ), 3, dim=0)

        # filter data via the main toolformer equation
        # NOTE(review): toolformer_probability_filter is not handed `model` here, so it
        # uses the module-level `model` - confirm that is the same object as the argument.
        token_start_idxs = [encode_to_tensor(tokenizer, prompts[i]).shape[-1] for i in batch_positions]
        finetune_tokens, finetune_tokens_with_api_calls, finetune_tokens_with_api_responses = toolformer_probability_filter(
            tokens_without_api_calls, tokens_with_api_calls, tokens_with_api_responses, api_start_token, api_end_token, start_idxs=token_start_idxs, device=device
        )

        # store the relevant data
        # NOTE(review): the variant *without* the API response is what gets stored -
        # confirm this is intended rather than finetune_tokens_with_api_responses.
        finetune_data.extend(f.cpu() for f in finetune_tokens_with_api_calls)

    if return_tokens: return finetune_data
    # drop the 0 padding added by pad_sequence before decoding
    return [tokenizer.decode([t.item() for t in f if not t == 0]) for f in finetune_data]

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

### Test (7B)

In [None]:
import json, csv

In [None]:
# Environment for a single-process run (rank 0 of a world of size 1).
os.environ['LOCAL_RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
# NOTE(review): machine-specific address and port - adjust for the host this runs on.
os.environ['MASTER_ADDR'] = '172.17.0.7'
os.environ['MASTER_PORT'] = '6006'

In [None]:
# Set up the model-parallel environment (helper from the toolformer star-imports).
local_rank, world_size = setup_model_parallel()
path = '/home/models/foundation/LLaMA/7B'
# NOTE(review): no map_location is given, so tensors load onto the device they were saved from.
checkpoint = torch.load(f'{path}/consolidated.00.pth')
with open(Path(path) / "params.json", "r") as f: params = json.loads(f.read())
model_args = ModelArgs(max_seq_len=2048, max_batch_size=8, **params)
# take the vocabulary size from the tokenizer
model_args.vocab_size = tokenizer.n_words
# build the network on the GPU in half precision
model = Transformer(model_args).cuda().half()
# presumably restores a default tensor type changed during setup - TODO confirm
torch.set_default_tensor_type(torch.FloatTensor)
# strict=False: checkpoint keys the model does not have are reported, not fatal
model.load_state_dict(checkpoint, strict=False)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


_IncompatibleKeys(missing_keys=[], unexpected_keys=['layers.0.attention.inner_attention.rope.freqs', 'layers.1.attention.inner_attention.rope.freqs', 'layers.2.attention.inner_attention.rope.freqs', 'layers.3.attention.inner_attention.rope.freqs', 'layers.4.attention.inner_attention.rope.freqs', 'layers.5.attention.inner_attention.rope.freqs', 'layers.6.attention.inner_attention.rope.freqs', 'layers.7.attention.inner_attention.rope.freqs', 'layers.8.attention.inner_attention.rope.freqs', 'layers.9.attention.inner_attention.rope.freqs', 'layers.10.attention.inner_attention.rope.freqs', 'layers.11.attention.inner_attention.rope.freqs', 'layers.12.attention.inner_attention.rope.freqs', 'layers.13.attention.inner_attention.rope.freqs', 'layers.14.attention.inner_attention.rope.freqs', 'layers.15.attention.inner_attention.rope.freqs', 'layers.16.attention.inner_attention.rope.freqs', 'layers.17.attention.inner_attention.rope.freqs', 'layers.18.attention.inner_attention.rope.freqs', 'layers.

In [None]:
# Read every row of the prompt dataset into a list of lists.
# newline='' lets the csv module handle line endings (including newlines embedded
# in quoted fields) itself, as its documentation requires.
with open('../data/dataset.csv', 'r', newline='') as file:
    d = list(csv.reader(file))

In [None]:
# number of rows read from the csv
len(d)

1168

In [None]:
# Wrap the rows in the project's prompt dataset and batch them 8 at a time,
# matching max_batch_size=8 used when building the model.
ds = PromptDS(d)
dl = DataLoader(ds, batch_size=8, num_workers=4)

In [None]:
# Encode the first element (the prompt) of every dataset item with both BOS and EOS added.
dt = [tokenizer.encode(p[0], True, True) for p in ds]

(566,574)

In [None]:
# Take the first eight dataset items and pull out their prompts for manual sampling.
idxs = list(range(0, 8))
dds = [ds[i] for i in idxs]
dds_prompts = [item[0] for item in dds]

In [None]:
# Build the finetuning set with the default delimiters; return_tokens=False yields decoded strings.
data = build_finetune_dataset(dl, model, tokenizer, return_tokens=False)

> [0;32m/tmp/ipykernel_4261/763301092.py[0m(18)[0;36mbuild_finetune_dataset[0;34m()[0m
[0;32m     16 [0;31m        [0mprompts[0m[0;34m,[0m [0;34m([0m[0mdata_without_api_calls[0m[0;34m,[0m [0mstart_idxs[0m[0;34m)[0m [0;34m=[0m [0mbatch[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 18 [0;31m        [0mdata_without_api_calls[0m [0;34m=[0m [0;34m[[0m[0mp[0m [0;34m+[0m [0md[0m [0;32mfor[0m [0mp[0m[0;34m,[0m[0md[0m [0;32min[0m [0mzip[0m[0;34m([0m[0mprompts[0m[0;34m,[0m [0mdata_without_api_calls[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m[0;34m[0m[0m
[0m[0;32m     20 [0;31m        [0;31m# generate samples with possible API calls, and filter to a single API call per prompt[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  prompts[0]


'Your task is to add calls to a Calculator API to a piece of text. \nThe calls should help you get information required to complete the text. \nYou can call the API by writing "<% Calculator(expression) %>" where "expression" is the expression to be computed.\nYou should simply return the same text with the API call included.\nHere are some examples of API calls: \nInput: The number in the next term is 18 + 12 x 3 = 54. \nOutput: The number in the next term is 18 + 12 x 3 = <% Calculator(18 + 12 * 3) %> 54. \nInput: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people. \nOutput: The population is 658,893 people. This is 11.4% of the national average of <% Calculator(658,893 / 11.4) %> 5,763,868 people. \nInput: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is three times less than the 2169 goals last year. \nOutput: A total of 252 qualifying matches were played, and 723 goals were sco

ipdb>  data_without_api_calls[0]


'Janet sells 16.0 - 3.0 - 4.0 = 9.0 duck eggs a day. She makes 9.0 * 2.0 = $18.0 every day at the farmer’s market. '


ipdb>  start_idxs[0]


tensor(1621)


ipdb>  n


> [0;32m/tmp/ipykernel_4261/763301092.py[0m(21)[0;36mbuild_finetune_dataset[0;34m()[0m
[0;32m     19 [0;31m[0;34m[0m[0m
[0m[0;32m     20 [0;31m        [0;31m# generate samples with possible API calls, and filter to a single API call per prompt[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 21 [0;31m        [0msampled_prompts[0m [0;34m=[0m [0msample[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0mtokenizer[0m[0;34m,[0m [0mprompts[0m[0;34m,[0m [0mmax_gen_len[0m[0;34m=[0m[0;36m100[0m[0;34m,[0m [0mdecode[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m        [0mdata_with_api_calls[0m[0;34m,[0m [0mindexes[0m [0;34m=[0m [0mfilter_and_retain_only_first_api[0m[0;34m([0m[0msampled_prompts[0m[0;34m,[0m [0mapi_start_char[0m[0;34m,[0m [0mapi_end_char[0m[0;34m,[0m [0mstart_idxs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;3

ipdb>  data_without_api_calls[0]


'Your task is to add calls to a Calculator API to a piece of text. \nThe calls should help you get information required to complete the text. \nYou can call the API by writing "<% Calculator(expression) %>" where "expression" is the expression to be computed.\nYou should simply return the same text with the API call included.\nHere are some examples of API calls: \nInput: The number in the next term is 18 + 12 x 3 = 54. \nOutput: The number in the next term is 18 + 12 x 3 = <% Calculator(18 + 12 * 3) %> 54. \nInput: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people. \nOutput: The population is 658,893 people. This is 11.4% of the national average of <% Calculator(658,893 / 11.4) %> 5,763,868 people. \nInput: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is three times less than the 2169 goals last year. \nOutput: A total of 252 qualifying matches were played, and 723 goals were sco

ipdb>  quit


In [None]:
# inspect the tokenizer's vocabulary attribute
tokenizer.vocab

In [None]:
# Debug check of the pad-buffer construction pattern used in `sample`.
# NOTE(review): the recorded output below is not filled with -1.
torch.full((8,680), -1).to('cuda').long()

tensor([[140107892849952, 94768956301296, 0,  ..., 140107892849952, 140107892849952, 197568495720],
        [32, 48, 94768957192976,  ..., 80, 65, 94768956049216],
        [94768956410192, 94768955679216, 1,  ..., 140107892849984, 140107892849984, -1],
        ...,
        [0, 7310593858020254331, 3616445622929465956,  ..., 94768955884784, -1, -1],
        [140107892850000, 140107892850000, -1,  ..., 0, 0, 0],
        [140107892850016, 94768956019856, 1,  ..., 94768956178224, 0, 94768956394992]],
       device='cuda:0')

In [None]:
# Sample up to 100 new tokens for the eight prompts selected above (returns token ids).
sample(model, tokenizer, dds_prompts, 100)

> [0;32m/tmp/ipykernel_4970/252355549.py[0m(15)[0;36msample[0;34m()[0m
[0;32m     13 [0;31m    [0mtotal_len[0m [0;34m=[0m [0mmin[0m[0;34m([0m[0mparams[0m[0;34m.[0m[0mmax_seq_len[0m[0;34m,[0m [0mmax_gen_len[0m [0;34m+[0m [0mmax_prompt_size[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 15 [0;31m    [0mtokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfull[0m[0;34m([0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mtotal_len[0m[0;34m)[0m[0;34m,[0m [0mtokenizer[0m[0;34m.[0m[0mpad_id[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m.[0m[0mlong[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m    [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0mprompt_tokens[0m[0;34m)[0m

ipdb>  n


> [0;32m/tmp/ipykernel_4970/252355549.py[0m(16)[0;36msample[0;34m()[0m
[0;32m     14 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m    [0mtokens[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfull[0m[0;34m([0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0mtotal_len[0m[0;34m)[0m[0;34m,[0m [0mtokenizer[0m[0;34m.[0m[0mpad_id[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m.[0m[0mlong[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 16 [0;31m    [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0mprompt_tokens[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m        [0mtokens[0m[0;34m[[0m[0mk[0m[0;34m,[0m [0;34m:[0m [0mlen[0m[0;34m([0m[0mt[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtens

ipdb>  tokens


tensor([[140107892849952, 140107892849952, 468151435369,  ..., 450971566112, 137438953582, 472446402661],
        [94768934595040, 94766974209360, 140106014609952,  ..., 94768934595040, 94766974209360, 140106014609632],
        [0, 0, 0,  ..., 140107892850016, 94768515512032, 1],
        ...,
        [0, 0, 0,  ..., 94768956475408, 0, 94768934595040],
        [94768956109072, 0, 94768934595040,  ..., 140107892850016, 94768956911072, 1],
        [0, 0, 0,  ..., 94768955756000, 0, 94768934595040]], device='cuda:0')


ipdb>  quit


In [None]:
# preview the first two generated examples
for p in data[:2]: print(p)

In [None]:
# Write one example per row. Each example is a single string, so it has to be wrapped
# in a list: writerow() iterates its argument, and a bare string would be written
# with one character per column. UTF-8 is set explicitly because the examples contain
# non-ASCII characters such as '→'.
with open('/home/libs/toolformer/data/finetune_dataset.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    for example in data: writer.writerow([example])