<a href="https://colab.research.google.com/github/yonikremer/final_project/blob/master/project_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo - this is a demo for the project - 

Here yu can try out the sampling method I introduced in my research.
I highly recommended to look at [my research project](https://github.com/yonikremer/final_project) for more information.

## Running instructions:
Windows:

press control + F9

MacOS:

press command + F9

If you get a pop-up window, choose run anyway.

# Set Up

In [1]:
import timeit, sys, os
from math import ceil, floor

if not os.path.isdir("/content/final_project"):
    os.system("git clone https://github.com/yonikremer/final_project.git")

sys.path.append('/content/final_project')
%pip install -q -r /content/final_project/demo_requirements.txt

from final_project.text_generator import TextGenerator
from final_project.tree_generator import TreeGenerator
from final_project.sampling_generator import SamplingGenerator

from transformers import AutoConfig

[K     |████████████████████████████████| 5.3 MB 4.8 MB/s 
[K     |████████████████████████████████| 163 kB 43.6 MB/s 
[K     |████████████████████████████████| 7.6 MB 44.8 MB/s 
[?25h

In [2]:
def compare_generators(non_grouped_generator: TextGenerator, prompt: str, num_tokens: int, group_size: int):
    """Compares grouped and non-grouped text generators"""
    print(f"Your prompt:")
    print(prompt)

    start_non_grouped = timeit.default_timer()
    non_grouped_ans: str = non_grouped_generator(prompt, num_tokens)
    stop_non_grouped = timeit.default_timer()
    non_grouped_time = stop_non_grouped - start_non_grouped
    print(f"Text generated by Non grouped sampling in {non_grouped_time} seconds:")
    print(non_grouped_ans)

    non_grouped_generator.set_group_size(group_size)
    grouped_generator = non_grouped_generator
    start_grouped_generation = timeit.default_timer()
    grouped_ans: str = grouped_generator(prompt, num_tokens)
    stop_grouped_generation = timeit.default_timer()
    grouped_time = stop_grouped_generation - start_grouped_generation
    print(f"Text generated by grouped sampling in {grouped_time} seconds:")
    print(grouped_ans)

# Use grouped sampling:

In [3]:
model_name = "facebook/opt-125m" #@param ["facebook/opt-125m", "facebook/opt-350m", "gpt2", "gpt2-medium", "bigscience/bloom-560m"]
prompt = "I had so much fun in the" #@param {type:"string"}
num_tokens = 4 #@param {type:"integer", min:1}
top_p = 0 #@param {type:"slider", min:0.0, max:1.0, step:0.05}
top_k = 1 #@param {type:"integer"}
group_size = 2 #@param {type:"integer"}
temperature = 1 #@param {type:"number", min:0.000000001}
generation_class = 'TreeGenerator' #@param ["SamplingGenerator", "TreeGenerator"]


vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

num_groups = ceil(num_tokens / group_size)
actual_top_k = min(top_k, floor(vocab_size * top_p))

if generation_class == "TreeGenerator":
    max_num_calls = sum((actual_top_k ** i) for i in range(num_groups))    
    previous_methods_max_num_calls = sum((actual_top_k ** i) for i in range(num_tokens))
    non_grouped_generator = TreeGenerator(model_name, 1, top_k, top_p, temperature)

elif generation_class == "SamplingGenerator":
    max_num_calls = num_groups
    previous_methods_max_num_calls = num_tokens
    if top_p == 1.0:
        top_p = None
    elif top_p == 0.0:
        top_k = 1
        top_p = None
    elif top_k >= vocab_size:
        top_k = None
    else:
        raise ValueError("When using sampling from distribution, You must use either top k or top k and no both")
    non_grouped_generator = SamplingGenerator(model_name, group_size,
                 temperature, top_k, top_p)
else:
    raise ValueError

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times to the model")
print(f"Previous methods will need up to {previous_methods_max_num_calls} call to the model to generate the same text")

compare_generators(non_grouped_generator, prompt, num_tokens, group_size)
del non_grouped_generator

Downloading:   0%|          | 0.00/651 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/251M [00:00<?, ?B/s]

The model will generate 4 tokens (words or parts of words)
It will call the model at most 1 times to the model
Previous methods will need up to 1 call to the model to generate the same text
Your prompt:
I had so much fun in the
Text generated by Non grouped sampling in 7.151214390999996 seconds:
I had so much fun in the first one.

Text generated by grouped sampling in 0.6674094199999985 seconds:
I had so much fun in the first  game I


# Change the hyper-parameters to see what will happen!