# Text Generator App

In [None]:
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource

from math import pi

import numpy as np

import tensorflow as tf

from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
from transformers import tf_top_k_top_p_filtering

import panel as pn
import panel.widgets as pnw
from panel.template import DefaultTheme

In [None]:
# Tokenizer and TensorFlow LM-head model for DistilGPT2.
# The tokenizer's EOS token id is handed to the model as its pad_token_id.
gpt2_distil_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
gpt2_distil_model = TFGPT2LMHeadModel.from_pretrained(
    "distilgpt2", pad_token_id=gpt2_distil_tokenizer.eos_token_id)

In [None]:
# Tokenizer and TensorFlow LM-head model for GPT2 (small).
# The tokenizer's EOS token id is handed to the model as its pad_token_id.
# These two objects are also the default model/tokenizer arguments of the
# helper functions defined further down in this file.
gpt2_small_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_small_model = TFGPT2LMHeadModel.from_pretrained(
    "gpt2", pad_token_id=gpt2_small_tokenizer.eos_token_id)

In [None]:
# Tokenizer and TensorFlow LM-head model for GPT2-Medium.
# The tokenizer's EOS token id is handed to the model as its pad_token_id.
gpt2_medium_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
gpt2_medium_model = TFGPT2LMHeadModel.from_pretrained(
    "gpt2-medium", pad_token_id=gpt2_medium_tokenizer.eos_token_id)

In [None]:
#adapted from the tf_top_k_top_p_filtering function to filter probabilities rather than logits
#source - https://huggingface.co/transformers/v2.9.1/_modules/transformers/modeling_tf_utils.html
def pr_top_k_top_p_filtering(probabilities,
                             top_k=0,
                             top_p=1.0,
                             filter_value=-float("Inf"),
                             min_tokens_to_keep=1):
    """Filter a probability distribution with top-k and/or nucleus (top-p) filtering.

    Entries that are filtered out are overwritten with ``filter_value``; the
    surviving entries are returned unchanged (they are not renormalised).

    Args:
        probabilities: probability distribution, indexed as
            (batch size, vocabulary size) by the slicing below.
        top_k: if > 0, keep only the ``top_k`` highest-probability tokens.
            The value is clamped to at least ``min_tokens_to_keep`` and at
            most the size of the last dimension.
        top_p: if < 1.0, walk the tokens in descending probability order and
            keep them up to and including the first token whose cumulative
            probability exceeds ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written into the removed positions
            (``-inf`` by default).
        min_tokens_to_keep: when > 1, the first ``min_tokens_to_keep``
            positions of the sorted removal mask are cleared before the mask
            is shifted, so those tokens survive the top-p step.

    Returns:
        The filtered tensor, same shape as ``probabilities``.

    Based on: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    probabilities_shape = shape_list(probabilities)

    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep),probabilities_shape[-1])  # Safety check
        # Remove all tokens with a probability less than the k-th largest probability
        # (the last value returned by tf.math.top_k, broadcast over the vocabulary axis)
        indices_to_remove = probabilities < tf.math.top_k(probabilities, k=top_k)[0][..., -1,None]
        probabilities = set_tensor_by_indices_to_value(probabilities, indices_to_remove,filter_value)

    if top_p < 1.0:
        # NOTE(review): if the top-k step above already ran, the removed
        # positions hold filter_value and the cumulative sum below is taken
        # over those values as-is, without renormalising the survivors.
        # Confirm this matches the filtering applied when actually sampling.
        sorted_indices = tf.argsort(probabilities, direction="DESCENDING")
        sorted_probabilities = tf.gather(
            probabilities, sorted_indices, axis=-1, batch_dims=1
        )  # expects probabilities to be of dim (batch_size, vocab_size)

        cumulative_probs = tf.math.cumsum(sorted_probabilities, axis=-1)

        # Mark tokens whose cumulative probability is above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p

        if min_tokens_to_keep > 1:
            # Clear the removal mask for the first min_tokens_to_keep sorted positions
            # (the shift below additionally protects the first position)
            sorted_indices_to_remove = tf.concat(
                [
                    tf.zeros_like(
                        sorted_indices_to_remove[:, :min_tokens_to_keep]),
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
                ],
                -1,
            )
        # Shift the mask one position to the right so that the first token
        # above the threshold is kept as well
        sorted_indices_to_remove = tf.roll(sorted_indices_to_remove,1,axis=-1)
        # tf.roll wraps the last mask entry around to position 0; overwrite
        # it with False so the most probable token is never removed
        sorted_indices_to_remove = tf.concat(
            [
                tf.zeros_like(sorted_indices_to_remove[:, :1]),
                sorted_indices_to_remove[:, 1:]
            ],
            -1,
        )
        # scatter the mask from sorted order back to the original vocabulary order
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
        probabilities = set_tensor_by_indices_to_value(probabilities, indices_to_remove,filter_value)
    return probabilities

def shape_list(x):
    """Return the shape of ``x`` as a list, one entry per axis.

    Axes whose size is known statically are reported as plain Python ints;
    axes whose static size is ``None`` fall back to the matching element of
    the dynamic ``tf.shape(x)`` tensor.
    """
    known_dims = x.shape.as_list()
    runtime_dims = tf.shape(x)
    dims = []
    for axis, size in enumerate(known_dims):
        if size is None:
            # static size unknown: use the value resolved at run time
            dims.append(runtime_dims[axis])
        else:
            dims.append(size)
    return dims


def set_tensor_by_indices_to_value(tensor, indices, value):
    """Return a copy of ``tensor`` with ``value`` written wherever ``indices`` is true.

    The input tensor is not modified in place; a tensor filled with ``value``
    is built and ``tf.where`` picks between it and the original element-wise.
    """
    # zeros_like + value gives a tensor of the same shape filled with `value`
    replacement = tf.zeros_like(tensor) + value
    masked = tf.where(indices, replacement, tensor)
    return masked


def scatter_values_on_batch_indices(values, batch_indices):
    """Scatter ``values`` to the per-row positions given by ``batch_indices``.

    For every row ``b`` and column ``j``, ``values[b, j]`` is written to
    position ``[b, batch_indices[b, j]]`` of a new tensor that has the shape
    of ``batch_indices``.
    """
    target_shape = shape_list(batch_indices)

    # row number of every element, flattened to a single row vector
    row_ids = tf.expand_dims(tf.range(target_shape[0]), axis=-1)
    flat_rows = tf.reshape(tf.broadcast_to(row_ids, target_shape), [1, -1])

    # column target of every element, flattened the same way
    flat_cols = tf.reshape(batch_indices, [1, -1])

    # stack into (row, column) pairs, one pair per element
    pair_indices = tf.transpose(tf.concat([flat_rows, flat_cols], 0))

    flat_values = tf.reshape(values, [-1])
    return tf.scatter_nd(pair_indices, flat_values, target_shape)

In [None]:
def get_next_token_logits(sequence='Please input some text',
                          model=gpt2_small_model,
                          tokenizer=gpt2_small_tokenizer):
    """
    Run the model on ``sequence`` and return the logits for the next token.

    The text is encoded with ``tokenizer`` (TensorFlow tensors), passed
    through ``model``, and the logits at the last sequence position are
    returned for every batch entry.
    """
    encoded = tokenizer.encode(sequence, return_tensors="tf")
    outputs = model(encoded)
    # the last position along the sequence axis holds the next-token scores
    return outputs.logits[:, -1, :]

In [None]:
def get_prediction(next_token_logits=[[]],
                   decoding_type='Greedy Search',
                   model=gpt2_small_model,
                   tokenizer=gpt2_small_tokenizer,
                   temperature=0.7,
                   top_k=50,
                   top_p=0.95):
    """
    Pick the next token from ``next_token_logits`` and return it as text.

    Args:
        next_token_logits: next-token logits; only the first batch entry is
            decoded. The default value is never mutated here.
        decoding_type: ``'Greedy Search'`` takes the arg-max of the logits;
            ``'Sampling'`` divides the logits by ``temperature``, filters them
            with ``tf_top_k_top_p_filtering`` and draws one sample.
        model: not used by this function; kept so existing callers that pass
            it positionally keep working.
        tokenizer: tokenizer used to decode the chosen token id.
        temperature: divisor applied to the logits (Sampling only).
        top_k: forwarded to ``tf_top_k_top_p_filtering`` (Sampling only).
        top_p: forwarded to ``tf_top_k_top_p_filtering`` (Sampling only).

    Returns:
        The decoded text of the chosen token for the first batch entry.

    Raises:
        ValueError: if ``decoding_type`` is not one of the two supported
            values (previously this surfaced as an UnboundLocalError).
    """
    # fixed global seed on every call - presumably so that sampling is
    # reproducible between runs
    tf.random.set_seed(60)
    if decoding_type == "Greedy Search":
        next_token = tf.math.argmax(next_token_logits,
                                    axis=-1,
                                    output_type=tf.int32)
    elif decoding_type == "Sampling":
        # apply a Temperature
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature
        # filter the logits before sampling from them
        filtered_next_token_logits = tf_top_k_top_p_filtering(
            next_token_logits, top_k, top_p)
        next_token = tf.random.categorical(filtered_next_token_logits,
                                           dtype=tf.int32,
                                           num_samples=1)
    else:
        raise ValueError(
            f"Unsupported decoding_type {decoding_type!r}; "
            "expected 'Greedy Search' or 'Sampling'")
    # first batch entry: a single id for greedy search, a one-element list
    # of ids for sampling
    prediction = tokenizer.decode(next_token.numpy().tolist()[0])
    return prediction

In [None]:
def filter_next_token_probabilities(next_token_logits=[[]],
                                    decoding_type='Greedy Search',
                                    model=gpt2_small_model,
                                    tokenizer=gpt2_small_tokenizer,
                                    temperature=0.7,
                                    top_k=50,
                                    top_p=0.95):
    """
    Turn next-token logits into the probabilities of the candidate tokens.

    Args:
        next_token_logits: next-token logits. The default value is never
            mutated here.
        decoding_type: ``'Greedy Search'`` returns the plain softmax of the
            logits (no temperature, no filtering); ``'Sampling'`` divides the
            logits by ``temperature``, applies softmax and then filters the
            result with ``pr_top_k_top_p_filtering``.
        model: not used by this function; kept for call compatibility.
        tokenizer: not used by this function; kept for call compatibility.
        temperature: divisor applied to the logits (Sampling only).
        top_k: forwarded to ``pr_top_k_top_p_filtering`` (Sampling only).
        top_p: forwarded to ``pr_top_k_top_p_filtering`` (Sampling only).

    Returns:
        Tensor of probabilities; for Sampling, filtered-out positions hold
        the filter value used by ``pr_top_k_top_p_filtering``.

    Raises:
        ValueError: if ``decoding_type`` is not one of the two supported
            values (previously this surfaced as an UnboundLocalError).
    """
    if decoding_type == "Greedy Search":
        filtered_next_token_probabilities = tf.nn.softmax(next_token_logits)
    elif decoding_type == "Sampling":
        # apply a Temperature
        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature
        # convert logits to probabilities
        next_token_probabilities = tf.nn.softmax(next_token_logits)
        # filter the probabilities with top-k / top-p
        filtered_next_token_probabilities = pr_top_k_top_p_filtering(
            next_token_probabilities, top_k, top_p)
    else:
        raise ValueError(
            f"Unsupported decoding_type {decoding_type!r}; "
            "expected 'Greedy Search' or 'Sampling'")
    return filtered_next_token_probabilities

In [None]:
def get_plot_data(filtered_next_token_probabilities, tokenizer):
    """
    Build the word list and probability list shown in the bar plot.

    Negative entries (such as the -inf filter value) are clamped to zero,
    then at most 100 of the highest remaining probabilities of the first
    batch entry are returned together with their decoded tokens.
    """
    # relu maps every negative value, including -inf, to zero
    non_negative = tf.nn.relu(filtered_next_token_probabilities)
    # number of bars: the non-zero entries, capped at 100 to keep the plot readable
    bar_count = min(100, tf.math.count_nonzero(non_negative).numpy())
    top_entries = tf.math.top_k(non_negative[0], bar_count)

    probability_list = top_entries.values.numpy().tolist()
    # decode each token id on its own so every bar gets one label
    word_list = [
        tokenizer.decode([token_id])
        for token_id in top_entries.indices.numpy()
    ]
    return word_list, probability_list


def clean_plot_data(word_list, probability_list):
    """
    Merge duplicate words and order the result for plotting.

    Probabilities of words that occur more than once are summed. The merged
    entries are returned as two parallel lists (words, probabilities), sorted
    by probability from highest to lowest; ties keep first-seen order.
    """
    totals = {}
    for word, probability in zip(word_list, probability_list):
        if word in totals:
            totals[word] = totals[word] + probability
        else:
            totals[word] = probability

    # sorted() is stable, so equal probabilities stay in insertion order
    ranked = sorted(totals.items(), key=lambda entry: entry[1], reverse=True)
    words = [word for word, _ in ranked]
    probabilities = [probability for _, probability in ranked]
    return words, probabilities

In [None]:
def get_plot(word_list, probability_list):
    """
    Build the Bokeh bar chart of words against their probabilities.

    The words form the categorical x-range, each bar's height is the
    matching probability, and the x-axis labels are rotated by pi/2.
    """
    source = ColumnDataSource(data={
        'word_list': word_list,
        'probability_list': probability_list,
    })

    figure_options = {
        'x_range': source.data['word_list'],
        'height': 250,
        'title': "Probabilities",
        # where the toolbar is placed
        'toolbar_location': "right",
        # interactive tools offered on the plot
        'tools': "hover,pan,wheel_zoom,box_zoom,zoom_in,zoom_out,reset",
        # text displayed while hovering over a bar
        'tooltips': "@word_list: @probability_list{0.0000}",
    }
    chart = figure(**figure_options)

    chart.vbar(x='word_list',
               top='probability_list',
               width=0.8,
               color='#d45781',
               source=source)
    chart.xaxis.major_label_orientation = pi / 2
    return chart

In [None]:
def get_perplexity(sequence, gpt2_model, gpt2_tokenizer):
    """
    Compute the perplexity of ``sequence`` under ``gpt2_model``.

    The encoded sequence is passed to the model as both input and labels;
    the perplexity is the exponential of the mean of the returned loss.
    """
    token_ids = gpt2_tokenizer.encode(sequence, return_tensors="tf")
    outputs = gpt2_model(input_ids=token_ids, labels=token_ids)
    mean_loss = tf.math.reduce_mean(outputs.loss)
    return tf.math.exp(mean_loss).numpy()

In [None]:
# Initialise the Panel extension with 'stretch_width' as the sizing mode.
pn.extension(sizing_mode='stretch_width')

In [None]:
#widget to choose the model; the option labels are the values that the
#button callback compares model_pn.value against
model_pn = pn.widgets.Select(options=['DistilGPT2', 'GPT2', 'GPT2-Medium'])

#widget to choose the decoding method
decoding_pn = pn.widgets.RadioBoxGroup(name='RadioBoxGroup',
                                       options=['Greedy Search', 'Sampling'],
                                       inline=True)

# widgets for the sampling parameters that control the model output for the predictions
# (the temperature slider starts at 0.01, so the logits are never divided by zero)
temperature_pn = pnw.FloatSlider(name='Temperature',
                                 value=1.0,
                                 start=0.01,
                                 end=3.0,
                                 step=0.01,
                                 bar_color='#d45781')
top_k_pn = pnw.IntSlider(name='Top K',
                         value=0,
                         start=0,
                         end=100,
                         bar_color='#d45781')
top_p_pn = pnw.FloatSlider(name='Top p',
                           value=1.0,
                           start=0.0,
                           end=1.0,
                           step=0.01,
                           bar_color='#d45781')
# number of tokens generated per button click
next_words_pn = pnw.IntSlider(name='Number of next predicted words',
                              value=1,
                              start=1,
                              end=30,
                              bar_color='#d45781')

#text input widget for the initial prompt
text_input = pn.widgets.TextInput(placeholder='Enter a string here...')

#pane for the generated text output, initialised with the current input value
generated_text = pn.pane.HTML(object=text_input.value,
                              background='#f0f0f0',
                              min_height=200,
                              sizing_mode="stretch_width")

#link the input's value to the pane's object
text_input.link(generated_text, value='object')

#button widget for starting the text generation
button = pn.widgets.Button(name="Generate", button_type='primary')

# Bokeh bar plot initialised with empty data, and the pane that displays it
word_list = list()
probability_list = list()
plot = get_plot(word_list, probability_list)
bokeh_plot = pn.pane.Bokeh(plot, sizing_mode="stretch_width")

#pane for displaying the calculated perplexity
perplexity_pn = pn.pane.HTML(object="Perplexity: ")

In [None]:
#Panels that make up the parts of the application
model_widget = pn.Column("#Model", model_pn, )

decoding_widget = pn.Column("##","#Decoding Method",decoding_pn)

#start with an empty panel; it is filled with the sampling widgets when Sampling is chosen
parameter_widgets = pn.Column()

#Header for the Sampling parameters, created separately so it can be shown/hidden together with the sliders
header_parameters = pn.pane.Markdown("""
##

#Parameters
""")

#Panel for setting how many times the text generation step is repeated
repeat_widget = pn.Column("##","#Repeat Prediction",next_words_pn)

In [None]:
#function to hide/show the parameters depending on the chosen decoding method
def hide_parameters(event):
    """Show the sampling controls for 'Sampling', take them out otherwise.

    ``event.new`` holds the newly selected decoding method. The header and
    the three sliders are appended to ``parameter_widgets`` when it equals
    'Sampling' and removed from it for any other value.
    """
    sampling_controls = (header_parameters, temperature_pn, top_k_pn,
                         top_p_pn)
    if event.new == 'Sampling':
        apply_change = parameter_widgets.append
    else:
        apply_change = parameter_widgets.remove
    for control in sampling_controls:
        apply_change(control)


#tying hide_parameters function to the decoding_pn
decoding_pn.param.watch(hide_parameters, 'value')

In [None]:
# function to update data by button click
def click_cb(event):
    """Generate the requested number of next tokens and refresh the display.

    For every repetition (``next_words_pn.value`` times) the callback:
      1. computes the next-token logits for the text currently shown,
      2. picks the next token with the selected decoding method,
      3. computes the (filtered) candidate probabilities for the bar plot,
      4. appends the predicted token to the generated text,
      5. recomputes the perplexity of the extended text,
      6. updates the perplexity pane and the bar plot.

    Args:
        event: the button click event (not inspected).

    Raises:
        ValueError: if the selected model name has no registered model.
    """
    # model name shown in the Select widget -> (model, tokenizer)
    models = {
        'DistilGPT2': (gpt2_distil_model, gpt2_distil_tokenizer),
        'GPT2': (gpt2_small_model, gpt2_small_tokenizer),
        'GPT2-Medium': (gpt2_medium_model, gpt2_medium_tokenizer),
    }

    # without any text there is nothing to condition the model on
    if not generated_text.object:
        return

    for _ in range(next_words_pn.value):
        bokeh_plot.loading = True
        try:
            selected = models.get(model_pn.value)
            if selected is None:
                raise ValueError(f"Unknown model {model_pn.value!r}")
            model, tokenizer = selected

            next_token_logits = get_next_token_logits(generated_text.object,
                                                      model, tokenizer)
            pred = get_prediction(next_token_logits, decoding_pn.value,
                                  model, tokenizer, temperature_pn.value,
                                  top_k_pn.value, top_p_pn.value)
            filtered_next_token_probabilities = filter_next_token_probabilities(
                next_token_logits, decoding_pn.value, model, tokenizer,
                temperature_pn.value, top_k_pn.value, top_p_pn.value)
            word_list, probability_list = get_plot_data(
                filtered_next_token_probabilities, tokenizer)
            generated_text.object += pred
            # perplexity is computed on the text including the new token
            perplexity = get_perplexity(generated_text.object, model,
                                        tokenizer)

            perplexity_pn.object = f"Perplexity:\n {perplexity}"

            word_list, probability_list = clean_plot_data(word_list,
                                                          probability_list)

            bokeh_plot.object = get_plot(word_list, probability_list)
        finally:
            # always clear the loading state, even if a step above failed
            bokeh_plot.loading = False


# update data by button click
button.on_click(click_cb)

In [None]:
# callback function in case the text input changes
def text_change_cb(event):
    """Copy the new value of the text input into the generated-text pane.

    Args:
        event: change event; its ``new`` attribute is assigned to
            ``generated_text.object``.
    """
    generated_text.object = event.new

# tying the callback function to the text_input widget
text_input.param.watch(text_change_cb, 'value')

In [None]:
#Build the application using FastListTemplate:
#the sidebar holds the model, decoding, sampling-parameter and repeat panels;
#the main area holds the prompt input, the button, the generated text, the plot and the perplexity
application = pn.template.FastListTemplate(
    title="Text Generator",
    sidebar=[model_widget, decoding_widget, parameter_widgets, repeat_widget],
    main=[text_input, button, generated_text, bokeh_plot, perplexity_pn],
    theme=DefaultTheme,
    theme_toggle=False,
    accent_base_color='#d45781',
    header_background='#d45781')

In [None]:
# Show the assembled application template.
application.show()