<a href="https://colab.research.google.com/github/yonikremer/grouped_sampling/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

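# Use grouped sampling

This demo generates text from the same prompt twice: once with regular (non-grouped) sampling and once with grouped sampling. It prints each result together with how long it took to generate.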

In [None]:
!pip install -q transformers grouped_sampling

import timeit, sys, os
from math import ceil, floor

from transformers import AutoConfig
from grouped_sampling import SamplingGenerator


def _timed_generation(generator, prompt: str, num_tokens: int):
    """Run ``generator`` on ``prompt`` and measure the wall-clock time of the call.

    Returns a ``(generated_text, elapsed_seconds)`` tuple. Only the newly
    generated text is returned because the call passes ``return_full_text=False``.
    """
    start = timeit.default_timer()
    text: str = generator(
        prompt_s=prompt,
        max_new_tokens=num_tokens,
        return_full_text=False,
    )["generated_text"]
    return text, timeit.default_timer() - start


def compare_generators(
        non_grouped_generator: "SamplingGenerator",
        prompt: str,
        num_tokens: int,
        group_size: int
        ):
    """Compare grouped and non-grouped text generation on the same prompt.

    The generator is first run as given (the non-grouped baseline). Its
    ``group_size`` is then temporarily set to ``group_size`` for the grouped run.
    Each generated text is printed together with the time it took.

    Args:
        non_grouped_generator: generator used for both runs; its ``group_size``
            attribute is restored to its original value afterwards, even if the
            grouped run raises.
        prompt: text to continue.
        num_tokens: number of new tokens requested from each run.
        group_size: group size used for the grouped run.
    """
    print("Your prompt:")
    print(prompt)

    non_grouped_ans, non_grouped_time = _timed_generation(
        non_grouped_generator, prompt, num_tokens
    )
    print(f"Text generated by Non grouped sampling"
          f" in {non_grouped_time} seconds:")
    print(non_grouped_ans)

    # Reuse the same (already loaded) generator for the grouped run, but put the
    # caller's setting back afterwards so the object is not left silently grouped.
    original_group_size = non_grouped_generator.group_size
    non_grouped_generator.group_size = group_size
    try:
        grouped_ans, grouped_time = _timed_generation(
            non_grouped_generator, prompt, num_tokens
        )
    finally:
        non_grouped_generator.group_size = original_group_size
    print(f"Text generated by grouped sampling"
          f" in {grouped_time} seconds:")
    print(grouped_ans)


model_name = "gpt2-medium" #@param ["facebook/opt-125m", "facebook/opt-350m", "gpt2", "gpt2-medium"]
prompt = "" #@param {type:"string"}
num_tokens = 100 #@param {type:"integer", min:1}
top_p = 1 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 100000 #@param {type:"integer"}
group_size = 100 #@param {type:"integer"}
temperature = 1 #@param {type:"number", min:0.000000001}

# The form does not constrain group_size; 0 would divide by zero below and a
# negative value would produce meaningless call counts.
if group_size < 1:
    raise ValueError("group_size must be a positive integer")

vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

# Upper bounds on the number of model calls reported below: one call per group
# of `group_size` tokens versus one call per token for previous methods.
num_groups = ceil(num_tokens / group_size)
# NOTE(review): `actual_top_k` is not used anywhere in this cell.
actual_top_k = min(top_k, floor(vocab_size * top_p))
max_num_calls = num_groups
previous_methods_max_num_calls = num_tokens

# Normalise the token-filtering settings before building the generator:
# - top_p == 0 is mapped to top_k = 1 (only the most likely token survives);
# - top_p == 1 keeps every token, so top-p filtering is disabled (None);
# - top_k >= vocab_size keeps every token, so top-k filtering is disabled (None).
# Asking for a real top-k and a real top-p at the same time is rejected.
if top_p == 0.0:
    top_k = 1
    top_p = None
else:
    if top_p == 1.0:
        top_p = None
    if top_k >= vocab_size:
        top_k = None
    if top_k is not None and top_p is not None:
        raise ValueError(
            "When sampling from a distribution, you must use either top k or top p, not both"
        )

# The baseline must really be non-grouped: group_size=1 means one model call per
# generated token, like previous methods. compare_generators switches it to
# `group_size` for the grouped run.
non_grouped_generator = SamplingGenerator(
    model_name=model_name,
    group_size=1,
    temp=temperature,
    top_k=top_k,
    top_p=top_p,
    end_of_sentence_stop=False,
    )

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times to the model")
print(f"Previous methods will need up to {previous_methods_max_num_calls} call to the model to generate the same text")

compare_generators(non_grouped_generator, prompt, num_tokens, group_size)
# Free the model before the cell is re-run with different hyper-parameters.
del non_grouped_generator

# Change the hyper-parameters above and run the cell again to see how the results change!