# Demo - this is a demo for the project

Here you can try out the sampling method I introduced in my research.
I highly recommend looking at [my research project](https://github.com/yonikremer/final_project) for more information.

# Set Up

In [1]:
import timeit, sys
from math import ceil

In [None]:
# Clone the project repository, then add the checkout directory to the import path.
# NOTE(review): the hard-coded path assumes the clone lands in /content (the usual Colab working directory) - confirm when running elsewhere.
!git clone https://github.com/yonikremer/final_project.git
sys.path.append('/content/final_project')

In [None]:
# Install the packages listed in the cloned project's requirements file (-q keeps pip's output quiet).
%pip install -q -r /content/final_project/requirements.txt

In [14]:
from final_project.text_generator import TextGenerator
from final_project.tree_generator import TreeGenerator
from final_project.sampling_generator import SamplingGenerator

In [15]:
from transformers import AutoConfig, BatchEncoding

# Important Notes:


*   Please make sure the model you choose is not too big for your hardware
*   gpt2 is the smallest version of gpt-2, gpt2-xl is the largest
*   bloom is the largest version of bloom



In [17]:
def compare_generators(grouped_generator: TextGenerator, non_grouped_generator: TextGenerator, prompt: str, num_tokens: int):
    """Run both generators on one prompt and print each result with its timing.

    The prompt is echoed first. The non-grouped generator is then called and
    its elapsed wall-clock time and output are printed, followed by the same
    for the grouped generator. Nothing is returned.

    Args:
        grouped_generator: Callable invoked as ``generator(prompt, num_tokens)``;
            reported as "grouped sampling".
        non_grouped_generator: Callable invoked the same way; reported as
            "Non grouped sampling".
        prompt: Text passed to both generators.
        num_tokens: Second argument passed to both generators.
    """

    def _timed_generation(generator):
        # Measure only the generator call itself, not the printing around it.
        began = timeit.default_timer()
        text: str = generator(prompt, num_tokens)
        return text, timeit.default_timer() - began

    print("Your prompt:")
    print(prompt)

    baseline_text, baseline_seconds = _timed_generation(non_grouped_generator)
    print(f"Text generated by Non grouped sampling in {baseline_seconds} seconds:")
    print(baseline_text)

    grouped_text, grouped_seconds = _timed_generation(grouped_generator)
    print(f"Text generated by grouped sampling in {grouped_seconds} seconds:")
    print(grouped_text)

# Use grouped sampling:

In [None]:
model_name = "facebook/opt-125m" #@param ["facebook/opt-125m", "facebook/opt-1.3b", "facebook/opt-350m", "gpt2", "gpt2-medium", "gpt2-large", "bigscience/bloom-1b1", "bigscience/bloom-560m"]
prompt = "I so much fun in the" #@param {type:"string"}
num_tokens = 4 #@param {type:"integer", min:1}
top_p = 0.4 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 10000000000000000000000 #@param {type:"integer"}
group_size = 2 #@param {type:"integer"}
temperature = 1 #@param {type:"number", min:0.000000001}
sampling_type = 'from distribution' #@param ["from distribution", "as tree"]


# Vocabulary size of the selected model, read from its configuration.
vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

# Number of groups needed to cover num_tokens (the last group may be partial).
num_groups = ceil(num_tokens / group_size)
# Upper bound on the candidates kept at each step. Rounded up so the call
# counts reported below are whole numbers while remaining an upper bound.
actual_top_k = min(top_k, ceil(vocab_size * top_p))

if sampling_type == "as tree":
    # Worst case for tree generation: every kept candidate branches again at
    # each level, with one level per group (per token for the baseline).
    max_num_calls = sum(actual_top_k ** depth for depth in range(num_groups))
    previous_methods_max_num_calls = sum(actual_top_k ** depth for depth in range(num_tokens))
    grouped_generator = TreeGenerator(model_name, group_size, top_k, top_p, temperature)
    # A group size of 1 is the non-grouped baseline.
    non_grouped_generator = TreeGenerator(model_name, 1, top_k, top_p, temperature)
else:
    max_num_calls = num_groups
    previous_methods_max_num_calls = num_tokens
    # Sampling from the distribution takes exactly one filter, so the unused
    # one is disabled by setting it to None.
    if top_p == 1.0:
        # top p of 1.0 filters nothing: rely on top k alone.
        top_p = None
    elif top_p == 0.0:
        # top p of 0.0 is treated as always taking the single best token.
        top_k = 1
        top_p = None
    elif top_k >= vocab_size:
        # top k covers the whole vocabulary: rely on top p alone.
        top_k = None
    else:
        raise ValueError("When using sampling from distribution, You must use either top k or top p and not both")
    grouped_generator = SamplingGenerator(model_name, group_size,
                 temperature, top_k, top_p)
    # A group size of 1 is the non-grouped baseline; reusing group_size here
    # would compare the grouped generator against itself.
    non_grouped_generator = SamplingGenerator(model_name, 1,
                 temperature, top_k, top_p)

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times to the model")
print(f"Previous methods will need up to {previous_methods_max_num_calls} call to the model to generate the same text")

compare_generators(grouped_generator, non_grouped_generator, prompt, num_tokens)

# Change the hyper-parameters to see what will happen!