---
title:
    Project Demo
description:
    Here you can try out the text-generation method I introduced in my research project.
    I highly recommend looking at [my research project](https://github.com/yonikremer/final_project) for more information.
show-code: false
# Each entry under `params` describes one input widget. Every key matches a
# variable assigned in the first code cell of the notebook.
params:
    model_name:
        label: Model
        input: select
        value: facebook/opt-125m
        choices:
            - facebook/opt-125m
            - facebook/opt-350m
            - gpt2
            - gpt2-medium
            - gpt2-large
            - bigscience/bloom-560m
        multi: false
    prompt:
        label: Enter your prompt
        input: text
        value: I had so much fun in the
        rows: 4
    generation_class:
        label: Select type of text generation
        input: select
        value: SamplingGenerator
        choices:
            - TreeGenerator
            - SamplingGenerator
        multi: false
    num_tokens:
        label: Completion length (in tokens)
        value: 4
        input: numeric
        min: 1
    group_size:
        label: Group Size
        input: numeric
        min: 2
        value: 2
    top_p:
        label: top p
        input: numeric
        min: 0
        max: 1
        value: 0
        step: 0.05
    top_k:
        label: top k
        input: numeric
        min: 1
        value: 1
    temperature:
        label: temperature
        input: numeric
        min: 0.1
        max: 100
        value: 1
        step: 0.01
---

In [None]:
# Default values for the app parameters. They are kept in sync with the
# `value:` entries of the `params` section in the YAML header of this notebook.
model_name = "facebook/opt-125m"
prompt = "I had so much fun in the"
num_tokens = 4
top_p = 0
top_k = 1
group_size = 2
temperature = 1
generation_class = "SamplingGenerator"

# Set Up

In [1]:

import timeit
from math import ceil, floor

%pip install -q -r ../demo_requirements.txt

from text_generator import compare_generators
from tree_generator import TreeGenerator
from sampling_generator import SamplingGenerator

from transformers import AutoConfig

[K     |████████████████████████████████| 4.9 MB 8.3 MB/s 
[K     |████████████████████████████████| 163 kB 25.6 MB/s 
[K     |████████████████████████████████| 6.6 MB 48.0 MB/s 
[?25h

# Use grouped sampling:

In [None]:
vocab_size = AutoConfig.from_pretrained(model_name).vocab_size

# Number of groups needed to cover num_tokens (a partial last group counts).
num_groups = ceil(num_tokens / group_size)
# Effective number of candidate tokens, bounded by both top-k and top-p.
# Clamp to at least 1: with top_p == 0 the floor is 0, and 0 ** 0 == 1 would
# make the call-count sums below collapse to 1 regardless of the length.
# top_p == 0 is handled as greedy decoding (top_k = 1) in the sampling branch.
actual_top_k = max(1, min(top_k, floor(vocab_size * top_p)))

if generation_class == "TreeGenerator":
    # Upper bound: one call per tree node above the last level, i.e.
    # 1 + k + k**2 + ... + k**(depth - 1) for branching factor k.
    max_num_calls = sum(actual_top_k ** depth for depth in range(num_groups))
    previous_methods_max_num_calls = sum(actual_top_k ** depth for depth in range(num_tokens))
    # group_size 1 is the non-grouped baseline; compare_generators receives
    # the requested group_size separately.
    non_grouped_generator = TreeGenerator(model_name, 1, top_k, top_p, temperature)

elif generation_class == "SamplingGenerator":
    max_num_calls = num_groups
    previous_methods_max_num_calls = num_tokens
    # Sampling supports top-k or top-p, not both: normalize the unused one to None.
    if top_p == 1.0:
        top_p = None
    elif top_p == 0.0:
        top_k = 1
        top_p = None
    elif top_k >= vocab_size:
        top_k = None
    else:
        raise ValueError(
            "When sampling from the distribution, you must use either top k or top p, not both"
        )
    # group_size 1 is the non-grouped baseline, matching the TreeGenerator branch;
    # compare_generators receives the requested group_size separately.
    non_grouped_generator = SamplingGenerator(model_name, 1,
                                              temperature, top_k, top_p)
else:
    raise ValueError(f"Unknown generation_class: {generation_class!r}")

print(f"The model will generate {num_tokens} tokens (words or parts of words)")
print(f"It will call the model at most {max_num_calls} times")
print(f"Previous methods will need up to {previous_methods_max_num_calls} calls to the model to generate the same text")

compare_generators(non_grouped_generator, prompt, num_tokens, group_size)
# Drop the reference once done (presumably so the model's memory can be reclaimed).
del non_grouped_generator

# Change the hyperparameters and see what happens!