# Data Generator

> Generate Toolformer-style training data: find the positions in model-generated text where the model is likely to start an API call.

In [None]:
# | default_exp data_generator

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
#| hide
# Run nbdev's export so the `#| export` cells of this notebook are written to the
# module named by the `default_exp` directive (data_generator).
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
from typing import List, Callable, Tuple

import torch
from einops import rearrange
from torchtyping import TensorType
from langchain import PromptTemplate

from toolformer.api import BaseAPI

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
class DataGenerator:
    """Generate Toolformer-style training data by locating candidate API-call positions.

    Work in progress: only the candidate-position search (`_generate_api_position`)
    is implemented; API sampling (`_sampling_api`) and filtering (`_filter_api`)
    are still stubs.
    """

    def __init__(self, config: dict, model: Callable, tokenizer: Callable, apis: List[BaseAPI],):
        """
        Args:
            config: configuration mapping. Reads, under ``config["data_generator"]``:
                ``api_start_character``, ``api_end_character``, ``api_output_character``,
                ``top_k``, ``sampling_threshold``, ``filtering_threshold`` and the
                optional ``max_new_tokens`` (generation is unbounded when absent).
            model: callable accepting ``input_ids=`` and returning an object with a
                ``.logits`` attribute (presumably a Hugging Face causal LM — confirm).
            tokenizer: callable tokenizer that also provides ``decode``.
            apis: the APIs that may be inserted into the text.
        """
        data_config = config["data_generator"]
        start_character = data_config["api_start_character"]
        end_character = data_config["api_end_character"]
        output_character = data_config["api_output_character"]

        # add a space, because when the model generates a token it also includes the leading "space"
        self.api_start_token = tokenizer(f' {start_character}', return_tensors="pt")["input_ids"][0]
        # NOTE(review): unlike the start/output tokens, no leading space here — confirm intended.
        self.api_end_token = tokenizer(f'{end_character}', return_tensors="pt")["input_ids"][0]
        # Despite its name this holds token ids (kept for backward compatibility).
        self.api_output_character = tokenizer(f' {output_character}', return_tensors="pt")["input_ids"][0]

        self.top_k = data_config["top_k"]
        self.sampling_threshold = data_config["sampling_threshold"]
        self.filtering_threshold = data_config["filtering_threshold"]
        # Optional safety bound on generated tokens; None keeps the old unbounded behaviour.
        self.max_new_tokens = data_config.get("max_new_tokens")

        self.apis = apis
        self.model = model
        self.tokenizer = tokenizer
        # TODO: handle for cases that the sentence contains ".\n\n"
        # NOTE(review): this is only the *first* id of ".\n\n"; with many tokenizers that is
        # the "." token alone, so generation may stop at the first period — confirm.
        self.eos_token_id = tokenizer(".\n\n")["input_ids"][0]

    def _sampling(self, logits: TensorType["batch_size", "seq_len", "vocab_size"]):
        """Stub: token sampling is not implemented yet."""
        pass

    def _generate_api_position(
        self,
        prompt_ids: TensorType["seq_len"],  # the ids of the prompt (1-D; no batching yet)
    ) -> Tuple[
        TensorType["n_positions"],  # The positions of api call
        TensorType["seq_len"],  # The generated text (prompt + new ids)
    ]:
        """Greedily extend ``prompt_ids`` and record where an API call could start.

        At every step the next-token distribution is computed. If the API start
        token is among the ``self.top_k`` most likely tokens, the current length
        (the index the next token will occupy) is recorded as a candidate
        position. Decoding is greedy (argmax) and stops when the produced token
        equals ``self.eos_token_id`` or after ``self.max_new_tokens`` new tokens
        when that limit is configured.

        Returns:
            A 1-D ``torch.long`` tensor of candidate positions, and the 1-D tensor
            of prompt ids followed by the generated ids.
        """
        # TODO: add support batch
        generated_ids = prompt_ids
        # Collect plain ints; converted once at the end to an integer (index) tensor.
        api_positions: List[int] = []
        max_new_tokens = self.max_new_tokens
        n_new_tokens = 0

        with torch.no_grad():
            while max_new_tokens is None or n_new_tokens < max_new_tokens:
                logits = self.model(
                    input_ids=generated_ids.unsqueeze(0),
                ).logits

                last_logit = logits[0, -1, :]
                probs = torch.softmax(last_logit, dim=-1)

                # find the configured top k tokens for api call
                top_k_tokens = torch.topk(probs, k=self.top_k, dim=-1).indices

                # assumes the api start marker encodes to a single id — TODO confirm
                if self.api_start_token in top_k_tokens:
                    api_positions.append(len(generated_ids))  # the current idx

                # TODO: sample (e.g. torch.multinomial) instead of greedy decoding
                next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
                generated_ids = torch.cat([generated_ids, next_token], dim=0)
                n_new_tokens += 1

                print("--------------------")
                print(f"next_token={next_token}")
                print(f"positions={api_positions}")
                print(f"text={self.tokenizer.decode(generated_ids)}")

                if next_token == self.eos_token_id:
                    break

        return torch.tensor(api_positions, dtype=torch.long), generated_ids

    def _sampling_api(
        self,
        positions: TensorType["n_positions"],
        generated_ids: TensorType["seq_len"],
        prompt: PromptTemplate
    ):
        """Stub: build the API-sampling prompt for each candidate position.

        For every position the ids before it are decoded to text and formatted
        into ``prompt``. Sampling the actual API calls is not implemented yet, so
        nothing is returned.
        """
        for position in positions:
            # Decode to text: formatting the raw id tensor would embed its repr in the prompt.
            condition_text = self.tokenizer.decode(generated_ids[: int(position)])
            conditioned_prompt = prompt.format(input=condition_text)
            for api in self.apis:
                # TODO: sample an API call for `api` conditioned on `conditioned_prompt`
                pass

    def _filter_api(
        self,
        idxs: List[int]
    ):
        """Stub: API-call filtering is not implemented yet."""
        pass

    def generate(
        self,
        prompt_tempalte: PromptTemplate,  # (sic) name kept for keyword-argument callers
        text: str,
    ) -> Tuple[
        TensorType["n_positions"],
        TensorType["seq_len"],
    ]:
        """Format ``text`` into the template and find candidate API-call positions.

        Returns:
            The ``(positions, generated_ids)`` pair produced by
            ``_generate_api_position``. Filtering is not implemented yet.
        """
        prompt = prompt_tempalte.format(input=text)
        prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]

        # sampling
        # TODO: add support batch
        positions, generated_ids = self._generate_api_position(prompt_ids)
        # TODO: filtering, then return the annotated texts
        return positions, generated_ids