<a href="https://colab.research.google.com/github/yonikremer/final_project/blob/master/project_demo.ipynb" target="_parent"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"/></a>

# Demo

Here you can try out the sampling method I introduced in my research.
I highly recommend looking at [my research project](https://github.com/yonikremer/final_project) for more information.

## Running instructions
Windows:

Press Ctrl + F9.

macOS:

Press Command + F9.

If a pop-up window appears, choose "Run anyway".

# Set Up

In [1]:
import os
import sys
import timeit
from math import ceil, floor

if not os.path.isdir("/content/final_project"):
    os.system("git clone https://github.com/yonikremer/final_project.git")

sys.path.append('/content/final_project')
%pip install -q -r /content/final_project/demo_requirements.txt

from final_project.text_generator import TextGenerator
from final_project.tree_generator import TreeGenerator
from final_project.sampling_generator import SamplingGenerator

from transformers import AutoConfig


In [2]:
def compare_generators(non_grouped_generator: TextGenerator, prompt: str, num_tokens: int, group_size: int) -> None:
    """Generates text with and without grouping and prints each result with its wall-clock time."""
    print("Your prompt:")
    print(prompt)

    # Time one generation with the generator exactly as it was passed in.
    start_non_grouped = timeit.default_timer()
    non_grouped_ans: str = non_grouped_generator(prompt, num_tokens)
    stop_non_grouped = timeit.default_timer()
    non_grouped_time = stop_non_grouped - start_non_grouped
    print(f"Text generated by Non grouped sampling in {non_grouped_time} seconds:")
    print(non_grouped_ans)

    # Time the same prompt again with the generator returned by set_group_size(group_size).
    grouped_generator = non_grouped_generator.set_group_size(group_size)
    start_grouped_generation = timeit.default_timer()
    grouped_ans: str = grouped_generator(prompt, num_tokens)
    stop_grouped_generation = timeit.default_timer()
    grouped_time = stop_grouped_generation - start_grouped_generation
    print(f"Text generated by grouped sampling in {grouped_time} seconds:")
    print(grouped_ans)
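
The timing pattern used in `compare_generators` can be tried without downloading a model. The sketch below is only an illustration: `time_call` and `fake_generator` are hypothetical names that do not exist in the project, and `fake_generator` simply stands in for any callable that takes a prompt and a number of tokens.

```python
import timeit
from typing import Callable, Tuple


def time_call(generate: Callable[[str, int], str], prompt: str, num_tokens: int) -> Tuple[str, float]:
    """Run generate(prompt, num_tokens) once and return (text, elapsed_seconds)."""
    start = timeit.default_timer()
    text = generate(prompt, num_tokens)
    elapsed = timeit.default_timer() - start
    return text, elapsed


def fake_generator(prompt: str, num_tokens: int) -> str:
    """Hypothetical stand-in for a text generator: appends placeholder tokens."""
    return prompt + " token" * num_tokens


text, seconds = time_call(fake_generator, "I so much fun in the", 4)
print(f"Generated in {seconds:.6f} seconds:")
print(text)
```

A single run like this measures wall-clock time, so it also includes any one-off costs such as warm-up; repeating the call and comparing the runs gives a steadier picture.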

# Use grouped sampling:

In [6]:
model_name = "gpt2-medium" #@param ["facebook/opt-125m", "facebook/opt-350m", "gpt2", "gpt2-medium", "bigscience/bloom-560m"]
prompt = "I so much fun in the" #@param {type:"string"}
num_tokens = 4 #@param {type:"integer", min:1}
top_p = 0 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 1 #@param {type:"integer"}
group_size = 2 #@param {type:"integer"}
temperature = 1 #@param {type:"number", min:0.000000001}
generation_class = 'TreeGenerator' #@param ["SamplingGenerator", "TreeGenerator"]


vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

num_groups = ceil(num_tokens / group_size)
actual_top_k = min(top_k, floor(vocab_size * top_p))

if generation_class == "TreeGenerator":
    max_num_calls = sum((actual_top_k ** i) for i in range(num_groups))    
    previous_methods_max_num_calls = sum((actual_top_k ** i) for i in range(num_tokens))
    non_grouped_generator = TreeGenerator(model_name, 1, top_k, top_p, temperature)

elif generation_class == "SamplingGenerator":
    max_num_calls = num_groups
    previous_methods_max_num_calls = num_tokens
    if top_p == 1.0:
        top_p = None
    elif top_p == 0.0:
        top_k = 1
        top_p = None
    elif top_k >= vocab_size:
        top_k = None
    else:
        raise ValueError("When sampling from the distribution, you must use either top-k or top-p, not both")
    non_grouped_generator = SamplingGenerator(model_name, group_size,
                 temperature, top_k, top_p)
else:
    raise ValueError(f"Unknown generation_class: {generation_class!r}")

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times")
print(f"Previous methods will need up to {previous_methods_max_num_calls} calls to the model to generate the same text")

compare_generators(non_grouped_generator, prompt, num_tokens, group_size)
del non_grouped_generator

The model will generate 4 tokens (words or parts of words)
It will call the model at most 15078 times
Previous methods will need up to 3427469592540 calls to the model to generate the same text
Your prompt:
I so much fun in the
Text generated by Non grouped sampling in 349.426794717 seconds:
I so much fun in the kitchen. I love
Text generated by grouped sampling in 163.15313102100004 seconds:
I so much fun in the kitchen It'ss
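
The two call counts in the output above can be reproduced without a model. The tree-generation bound in the cell is the geometric sum 1 + k + k^2 + ... over the number of generation steps, where k is the branching factor (`actual_top_k`). The recorded figures are consistent with k = 15077; that value is inferred from the arithmetic and is not stated anywhere in the notebook, and it differs from what the parameter values shown in the cell would produce (`top_k = 1` with `top_p = 0` gives a branching factor of 0). `max_tree_calls` is a hypothetical helper name used only for this sketch.

```python
from math import ceil


def max_tree_calls(branching: int, steps: int) -> int:
    """Nodes in a full tree with `steps` levels: 1 + b + b**2 + ... + b**(steps - 1)."""
    return sum(branching ** level for level in range(steps))


# Assumed values: branching factor inferred from the recorded output.
num_tokens, group_size, branching = 4, 2, 15077
num_groups = ceil(num_tokens / group_size)

grouped_calls = max_tree_calls(branching, num_groups)
ungrouped_calls = max_tree_calls(branching, num_tokens)
print(grouped_calls)    # 15078
print(ungrouped_calls)  # 3427469592540
```

Halving the number of steps from 4 to 2 removes the two largest terms of the sum, which is why the bound drops by roughly eight orders of magnitude here.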


# Change the hyper-parameters to see what will happen!
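
Before rerunning the cell, it can help to see how `group_size` alone changes the number of model calls. For `SamplingGenerator`, the cell above sets the bound to `ceil(num_tokens / group_size)`; the short sketch below tabulates that bound for a few group sizes. `sampling_calls` is a hypothetical helper name, not part of the project.

```python
from math import ceil


def sampling_calls(num_tokens: int, group_size: int) -> int:
    """Upper bound on model calls when tokens are generated group_size at a time."""
    return ceil(num_tokens / group_size)


num_tokens = 4
for group_size in (1, 2, 3, 4, 8):
    calls = sampling_calls(num_tokens, group_size)
    print(f"group_size={group_size}: at most {calls} call(s)")
```

A group size of 1 corresponds to generating one token per call, and any group size at or above `num_tokens` needs a single call.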