# Modelos de lenguaje

Un modelo de lenguaje es una función que estima la probabilidad de la siguiente palabra (o *token*) condicionada al texto que la precede. Aquí vamos a usar el modelo de lenguaje GPT-2 para predecir la continuación de una frase y para llamar la atención sobre construcciones poco probables.

# Language models

A language model is a function that estimates the probability of the next word (or *token*) conditioned on the text that precedes it. Here we are going to use the GPT-2 language model to predict the continuation of a sentence and to draw attention to unlikely constructions.

In [None]:
from IPython.core.display import display, HTML
import ipywidgets as widgets

import sys
import random

import torch
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [24]:
# Based on https://github.com/huggingface/transformers/blob/master/examples/run_generation.py

MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop

# Names of the pretrained checkpoints known to each config class (kept from the
# upstream script; not used in this notebook section). Newer transformers
# releases no longer define ``pretrained_config_archive_map`` on config classes,
# so fall back to an empty map instead of failing the whole cell with
# AttributeError.
ALL_MODELS = sum(
    (tuple(getattr(conf, 'pretrained_config_archive_map', {}).keys()) for conf in (GPT2Config,)),
    (),
)

# Maps a model-type name to its (model class, tokenizer class) pair.
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
}


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering acts on the last dimension and is done IN PLACE: the ``logits``
    tensor passed in is modified, and the same tensor is returned.

        Args:
            logits: logits distribution, shape (vocabulary size) or
                (batch size x vocabulary size).
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the smallest set of most likely tokens whose cumulative
                probability reaches top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written into every removed position.
        Returns:
            ``logits`` after filtering.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens whose logit is below the k-th largest one
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...shifted right by one so the first token crossing the threshold is kept
        # and the most likely token is never removed.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. Scattering along
        # the last dimension (rather than a hard-coded dim=1) also supports 1-D logits.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, tokenizer, length, context, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    device='cpu'):
    """Score every token of ``context`` and then extend it by ``length`` tokens.

    Args:
        model: causal language model called as ``model(input_ids=...)``; element
            ``[0]`` of its output is used as the logits, indexed as
            (batch, position, vocabulary).
        tokenizer: provides ``decode(token_id)`` to turn a token id into text.
        length: number of tokens to generate after the context.
        context: sequence of token ids; must not be empty.
        temperature: logits are divided by it when > 0; 0 selects greedy decoding.
        top_k, top_p: passed to ``top_k_top_p_filtering`` before picking a token.
        repetition_penalty: CTRL penalty (https://arxiv.org/abs/1909.05858) applied
            to every token id already in the sequence; 1.0 disables it.
        device: torch device on which the input ids are created.

    Returns:
        A list of ``len(context) + length`` ``(token_text, logit)`` pairs:
        - the first context token with logit 0 (nothing precedes it);
        - every other context token with the temperature-scaled logit the model
          gave it after seeing all earlier tokens;
        - the ``length`` generated tokens with their temperature-scaled logit
          before the repetition penalty and filtering.
        These are raw logits, not probabilities.

    Raises:
        ValueError: if ``context`` is empty.
    """
    if len(context) == 0:
        raise ValueError('context must contain at least one token')
    scale = temperature if temperature > 0 else 1.
    context = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)  # (1, n_context)
    logits = [(tokenizer.decode(context[0, 0].item()), 0)]
    with torch.no_grad():
        # The model is causal: the output at position t only depends on tokens
        # 0..t, so one forward pass over the whole context scores every context
        # token instead of re-running the model on each prefix (quadratic work).
        context_logits = model(input_ids=context)[0][0] / scale  # (n_context, vocab)
        for position in range(1, context.shape[1]):
            token_id = context[0, position].item()
            logits.append((tokenizer.decode(token_id), context_logits[position - 1, token_id].item()))

        generated = context
        # (1, vocab) prediction that follows the last context token.
        next_token_logits = context_logits[-1:].clone()
        for step in range(length):
            if step > 0:
                # Note: we could also use 'past' with GPT-2 (cached hidden-states)
                next_token_logits = model(input_ids=generated)[0][:, -1, :] / scale
            # Keep the unpenalized scores: penalty and filtering below work in place.
            raw_logits = next_token_logits.clone()

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858): make
            # already-seen tokens less likely. Negative logits are multiplied because
            # dividing them would move them towards 0 and *raise* their probability.
            for previous in set(generated[0].tolist()):
                if next_token_logits[0, previous] < 0:
                    next_token_logits[0, previous] *= repetition_penalty
                else:
                    next_token_logits[0, previous] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy decoding
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            token_id = next_token[0, 0].item()
            logits.append((tokenizer.decode(token_id), raw_logits[0, token_id].item()))
    return logits

In [2]:
# Model and decoding settings.
model_type, model_name_or_path = 'gpt2', 'gpt2-xl'
length = 20
temperature = 0
repetition_penalty = 1.0
top_k, top_p = 0, 0.9
no_cuda = False

# Run on the GPU only when one is available and it has not been disabled.
device = torch.device("cpu" if not torch.cuda.is_available() or no_cuda else "cuda")
n_gpu = torch.cuda.device_count()

# Load tokenizer and weights for the chosen model type, then switch to inference mode.
model_type = model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[model_type]
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path)
model.to(device)
model.eval()

# Clamp `length`: a negative value means "as long as the model allows" (or
# MAX_LENGTH when the config reports no limit); never exceed the model size.
if length < 0:
    if model.config.max_position_embeddings > 0:
        length = model.config.max_position_embeddings
    else:
        length = MAX_LENGTH  # avoid infinite loop
elif 0 < model.config.max_position_embeddings < length:
    length = model.config.max_position_embeddings  # No generation bigger than model size

In [63]:
# Input box for the sentence to analyse and an HTML pane for the coloured result.
output = widgets.HTML()
text = widgets.Textarea(layout=widgets.Layout(height="100px", width="100%"))
display(text, output)

def _logit_to_unit_interval(logit):
    """Squash a logit into (0, 1) with a numerically stable logistic function."""
    # 0.5 * (1 + tanh(x / 2)) equals exp(x) / (1 + exp(x)) but never overflows;
    # the ratio of exponentials becomes inf / inf = nan for large logits, and
    # int(255 * nan) then raises.
    return 0.5 * (1.0 + np.tanh(logit / 2.0))


def _tokens_to_html(logits, predict):
    """Render ``(token, logit)`` pairs: shaded context tokens, then the last ``predict`` tokens in gray."""
    from html import escape

    n_context = len(logits) - predict
    spans = []
    for position, (token, logit) in enumerate(logits[:n_context]):
        # The first token has no prediction, so it is always shown as "likely".
        score = _logit_to_unit_interval(logit) if position > 0 else 1
        shade = int(255 * score)
        color = 'black' if score > 0.5 else 'white'
        # Token text comes from user input: escape it so '<', '&', etc. render literally.
        spans.append(f'<span style="background-color: rgb(255,{shade},{shade}); color: {color}">'
                     f'{escape(token)}</span>')
    continuation = ''.join(token for token, _ in logits[n_context:])
    spans.append(f'<span style="color: gray">{escape(continuation)}</span>')
    return ''.join(spans)


def on_text_changed(b):
    """Re-score the textarea contents whenever they change.

    ``b`` is the change notification delivered by ``text.observe``; ``b.new`` is
    the new text. Each typed token is shown on a background from white (high
    score) to red (low score), followed in gray by a greedy 5-token continuation.

    NOTE(review): the score is a logistic squash of the raw logit returned by
    ``sample_sequence``, not the token's probability (that would require a
    softmax over the whole vocabulary).
    """
    if not b.new:
        output.value = ''
        return
    context_tokens = tokenizer.encode(b.new, add_special_tokens=False)
    predict = 5  # number of tokens generated after the typed text
    logits = sample_sequence(
        model=model,
        tokenizer=tokenizer,
        context=context_tokens,
        length=predict,
        temperature=0,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        device=device,
    )
    output.value = _tokens_to_html(logits, predict)


text.observe(on_text_changed, names="value", type="change")

Textarea(value='', layout=Layout(height='100px', width='100%'))

HTML(value='')