<a href="https://colab.research.google.com/github/middlebury-csci-0451/middlebury-csci-0451.github.io/blob/main/lecture-notes/text-generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/middlebury-csci-0451/CSCI-0451/blob/main/lecture-notes/text-generation.ipynb" target="_parent">Open these notes in Google Colab</a>

<a href="https://colab.research.google.com/github/middlebury-csci-0451/CSCI-0451/blob/main/lecture-notes/text-generation-live.ipynb" target="_parent">Open the live version in Google Colab</a>

## Text Generation

In this set of notes, we'll see a simple example of how to design and train models that perform *text generation*. Large language models (often called *chatbots*) are one familiar technology that uses text generation, while autocomplete features on websites and your devices are another. The text generation task is: 

> Given a text prompt, return a sequence of text that appears realistic as a follow-up to that prompt. 

Except for a brief foray into unsupervised learning, almost all of our attention in this course has been focused on prediction problems. At first glance, it may not appear that text generation involves any prediction at all. However, modern approaches to text generation rely fundamentally on supervised learning through the framework of *next token prediction*. 

## Next Token Prediction

The *next token prediction* problem is to predict a single *token* in terms of previous tokens. A *token* is a single "unit" of text. What counts as a unit is somewhat flexible. In some cases, each token might be a single character: "a" is a token, "b" is a token, etc. In other cases, each token might be a word. [Many modern models do something in between and let tokens represent common short sequences of characters using *[byte-pair encoding](https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt)*.]{.aside}

For this set of lecture notes, we're going to treat *words* and *punctuation* as tokens. The next token prediction problem is: 

> Given a sequence of tokens, predict the next token in the sequence. 

For example, suppose that our sequence of tokens is 

> "A computer science student"

We'd like to predict the next token in the sequence. Some likely candidates: 

- "*is*"
- "*codes*" 
- "*will*"

etc. On the other hand, some unlikely candidates: 

- "*mango*"
- "*grassy*"
- "*tree*"

So, we can think of this as a prediction, even a classification problem: the sequence "*A computer science student*" might be classified as "the category of sequences that are likely to be followed by the word *is*". 

Once we have trained a model, the text generation task involves asking that model to make predictions, using those predictions to form new tokens, and then feeding those new tokens into the model again to get even more new tokens, etc. 

In [1]:
import pandas as pd
import torch
import numpy as np
import string
from torchsummary import summary
from torchtext.vocab import build_vocab_from_iterator
import torch.utils.data as data
from torch import nn
from torch.nn.functional import relu
import re

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Our Task

Today, we are going to see whether we can teach an algorithm to understand and reproduce the pinnacle of cultural achievement; the benchmark against which all art is to be judged; the mirror that reveals to humany its truest self. I speak, of course, of *Star Trek: Deep Space Nine.*

<figure class="image" style="width:300px">
  <img src="https://raw.githubusercontent.com/PhilChodrow/PIC16B/master/_images/DS9.jpg" alt="">
  <figcaption><i></i></figcaption>
</figure>

In particular, we are going to attempt to teach a neural  network to generate *episode scripts*. This a text generation task: after training, our hope is that our model will be able to create scripts that are reasonably realistic in their appearance. 


In [2]:
## miscellaneous data cleaning

start_episode = 20
num_episodes = 25

url = "https://github.com/PhilChodrow/PIC16B/blob/master/datasets/star_trek_scripts.json?raw=true"
star_trek_scripts = pd.read_json(url)

cleaned = star_trek_scripts["DS9"].str.replace("\n\n\n\n\n\nThe Deep Space Nine Transcripts -", "")
cleaned = cleaned.str.split("\n\n\n\n\n\n\n").str.get(-2)
text = "\n\n".join(cleaned[start_episode:(start_episode + num_episodes)])
for char in ['\xa0', 'à', 'é', "}", "{"]:
    text = text.replace(char, "")

This is a *long* string of text. 

In [3]:
len(text)

788662

Here's what it looks like when printed: 

In [4]:
print(text[0:500])

  Last
time on Deep Space Nine.  
SISKO: This is the emblem of the Alliance for Global Unity. They call
themselves the Circle. 
O'BRIEN: What gives them the right to mess up our station? 
ODO: They're an extremist faction who believe in Bajor for the
Bajorans. 
SISKO: I can't loan you a Starfleet runabout without knowing where you
plan on taking it. 
KIRA: To Cardassia Four to rescue a Bajoran prisoner of war. 
(The prisoners are rescued.) 
KIRA: Come on. We have a ship waiting. 
JARO: What you 


The string in raw form doesn't look quite as nice: 

In [5]:
text[0:100]

'  Last\ntime on Deep Space Nine.  \nSISKO: This is the emblem of the Alliance for Global Unity. They c'

## Data Prep 

### Tokenization

In order to feed this string into a language model, we are going to need to split it into tokens. For today, we are going to treat punctuation, newline `\n` characters, and words as tokens. Here's a hand-rolled tokenizer that achieves this: 

In [6]:
def tokenizer(text):
    
    # empty list of tokens
    out = []
    
    # start by splitting into lines and candidate tokens
    # candidate tokens are separated by spaces
    L = [s.split() for s in text.split("\n")]
    
    # for each list of candidate tokens 
    for line in L:
        # scrub punctuation off beginning and end, adding to out as needed
        for token in line:             
            while (len(token) > 0) and (token[0] in string.punctuation):
                out.append(token[0])
                token = token[1:]
            
            stack = []
            while (len(token) > 0) and (token[-1] in string.punctuation):
                stack.insert(0, token[-1]) 
                token = token[:-1]
            
            out.append(token)
            if len(stack) > 0:
                out += stack
        out += ["\n"]
    
    # return the list of tokens, except for the final \n
    return out[:-1]

Here's this tokenizer in action: 

In [7]:
tokenizer("Last\ntime on Deep Space Nine. \n SISKO: This")

['Last',
 '\n',
 'time',
 'on',
 'Deep',
 'Space',
 'Nine',
 '.',
 '\n',
 'SISKO',
 ':',
 'This']

Let's tokenize the entire string: 

In [8]:
token_seq = tokenizer(text)

### Assembling the Data Set 

What we're now going to do is assemble the complete list of tokens into a series of predictor sequences and target tokens. The code below does this. The `WINDOW` controls how long each predictor sequence should be, and the `STEP` controls how many sequences we extract. A `STEP` of 1 would be all possible sequences. I've increased the `STEP` to 50 to reduce the size of our data for practical purposes. 

In [9]:
seq_len = 10
STEP = 1

predictors = []
targets    = []

for i in range(0, len(token_seq) - seq_len - 1, STEP):
    predictors.append(token_seq[i:(i+seq_len)])
    targets.append(token_seq[seq_len+i])

Here's how this looks: 

In [10]:
for i in range(100, 105):
    print(predictors[i], end = "")
    print(" | " + targets[i])

[')', '\n', 'KIRA', ':', 'Come', 'on', '.', 'We', 'have', 'a'] | ship
['\n', 'KIRA', ':', 'Come', 'on', '.', 'We', 'have', 'a', 'ship'] | waiting
['KIRA', ':', 'Come', 'on', '.', 'We', 'have', 'a', 'ship', 'waiting'] | .
[':', 'Come', 'on', '.', 'We', 'have', 'a', 'ship', 'waiting', '.'] | 

['Come', 'on', '.', 'We', 'have', 'a', 'ship', 'waiting', '.', '\n'] | JARO


Our next task is to convert all these tokens into unique integers, just like we did for text classification (because this basically *is* still text classification). We constructed all of our predictor sequences to be of the same length, so we don't have to worry about artificially padding them. This makes our task of preparing the data set much easier. 

In [11]:
vocab = build_vocab_from_iterator(iter(predictors), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

X = [vocab(x) for x in predictors]
y = vocab(targets)

## here's how our data looks now: 

for i in range(100, 105):
    print(X[i], end = "")
    print(" | " + str(y[i]))

[19, 1, 28, 3, 302, 22, 2, 83, 23, 10] | 161
[1, 28, 3, 302, 22, 2, 83, 23, 10, 161] | 448
[28, 3, 302, 22, 2, 83, 23, 10, 161, 448] | 2
[3, 302, 22, 2, 83, 23, 10, 161, 448, 2] | 1
[302, 22, 2, 83, 23, 10, 161, 448, 2, 1] | 399


Since our predictors are all in the same shape, we can go ahead and immediately construct the tensors and data sets we need: 

In [12]:
n = len(X)

X = torch.tensor(X, dtype = torch.int64).reshape(n, seq_len).to(device)
y = torch.tensor(y).to(device)

data_set    = data.TensorDataset(X, y)
data_loader = data.DataLoader(data_set, shuffle=True, batch_size=128)

In [13]:
X, y = next(iter(data_loader))
print(X.shape, y.shape)

torch.Size([128, 10]) torch.Size([128])


In [14]:
len(data_loader)

1511

## Modeling

Our model is going to be relatively simple. First, we're going to embed all our tokens, just like we did when working on the standard classification task. Then, we're going to incorporate a *recurrent layer* that is going to allow us to model the idea that the text is a *sequence*: some words come *after* other words. 

### Recurrent Architecture

Atop our word embedding layer we also incorporate a *long short-term memory* layer or LSTM. LSTMs are a type of *recurrent* neural network layer.  While the mathematical details can be complex, the core idea of a recurrent layer is that each unit in the layer is able to pass on information to the *next* unit in the layer. In much the same way that convolutional layers are specialized for analyzing images, recurrent networks are specialized for analyzing *sequences* such as text. 

![](http://karpathy.github.io/assets/rnn/diags.jpeg)

*Image from Andrej Karpathy's blog post, "The Unreasonable Effectiveness of Recurrent Neural Networks"*

After passing through the LSTM layer, we'll extract only the final sequential output from that layer, pass it through a final nonlinearity and fully-connected layer, and return the result. 

In [15]:
class TextGenModel(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size = 100, num_layers = 1, batch_first = True)
        self.fc   = nn.Linear(100, vocab_size)
        
    def forward(self, x):
        x = self.embedding(x)
        x, (hn, cn) = self.lstm(x)
        x = x[:,-1,:]
        x = self.fc(relu(x))
        return(x)
    
TGM = TextGenModel(len(vocab), 10).to(device)

Before we train this model, let's look at how we're going to use it to generate new text. We first start at the level of *predictions* from the model. Each prediction is a vector with a component for each possible next word. Let's call this vector $\hat{\mathbf{y}}$. We're going to use this vector to create a probability distribution over possible next tokens: the probability of selecting token $j$ from the set of all possible $m$ tokens is: 

$$
\hat{p}_j = \frac{e^{\frac{1}{T}\hat{y}_j}}{\sum_{j' = 1}^{m} e^{\frac{1}{T}\hat{y}_{j'}}}
$$

In the lingo, this operation is the "SoftMax" of the vector $\frac{1}{T}\hat{\mathbf{y}}$. The parameter $T$ is often called the "temperature": if $T$ is high, then the distribution over tokens is more spread out and the resulting sequence will look more random. [Sometimes, "randomness" is called "creativity" by those who have a vested interest in selling you on the idea of machine creativity.]{.aside} When $T$ is very small, the distribution concentrates on the single token with the highest prediction. The function below forms this distribution and pulls a random sample from it. 

In [16]:
all_tokens = vocab.get_itos()

def sample_from_preds(preds, temp = 1):
    probs = nn.Softmax(dim=0)(1/temp*preds)
    sampler = torch.utils.data.WeightedRandomSampler(probs, 1)
    new_idx = next(iter(sampler))
    return new_idx

The next function tokenizes some text, extracts the most recent tokens, and returns a new token. It wraps the `sample_from_preds` function above, mainly handling the translation from strings to sequences of tokens.  

In [21]:
def sample_next_token(text, temp = 1, window = 10):
    token_ix = vocab(tokenizer(text)[-window:])
    X = torch.tensor([token_ix], dtype = torch.int64).to(device)
    preds = TGM(X).flatten()
    new_ix = sample_from_preds(preds, temp)
    return all_tokens[new_ix]

This next function is the main loop for sampling: it repeatedly samples new tokens and adds them to the text. 

In [22]:
def sample_from_model(seed, n_tokens, temp, window):
    text = seed 
    text += "\n" + "-"*80 + "\n"
    for i in range(n_tokens):
        token = sample_next_token(text, temp, window)
        if (token not in string.punctuation) and (text[-1] not in "\n(["):
            text += " "
        text += token
    return text    

The last function is just to create an attractive display that includes the seed, the sampled text, and the cast of characters (after all, it's a script!). 

In [23]:
def sample_demo(seed, n_tokens, temp, window):
    synth = sample_from_model(seed, n_tokens, temp, window)
    cast = set(re.findall(r"[A-Z']+(?=:)",synth))
    print("CAST OF CHARACTERS: ", end = "")
    print(cast)
    print("-"*80)
    print(synth)

Let's go ahead and try it out! Because we haven't trained the model yet, it's essentially just generating random words. 

In [25]:
seed = "SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.\nO'BRIEN: What gives them the right to mess up our station?"

sample_demo(seed, 100, 1, seq_len)

CAST OF CHARACTERS: {"O'BRIEN", 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
examined flower Exile carrying column seventy-five Rumour grubs intrusions Mount dig Starts Flies rah farmland injector trail con-artist brave noticing lasers fell helpful injured Close godson's backing Containment Funny Seventeen trolley section Kolat Bellows identification Stardate Kibberian wheel Killing party's manoeuvres Trills powerless court's Allow than botanist spiral Bajorans neutrinos milk competitor's lifesigns misunderstood Morn lights Lang's shielded long-term anesthizine suite true Lunar interact summary proposing lawyer at brokering intended silver-haired declare hour thorium retrieve studied grindstone conference Outside

Ok, let's finally train the model! 

In [26]:
import time

lr = 0.001

optimizer = torch.optim.Adam(TGM.parameters(), lr = lr)
loss_fn = torch.nn.CrossEntropyLoss()

def train(dataloader):
    
    epoch_start_time = time.time()
    # keep track of some counts for measuring accuracy
    total_count, total_loss = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (X, y) in enumerate(dataloader):

        # zero gradients
        optimizer.zero_grad()
        # form prediction on batch
        preds = TGM(X)
        # evaluate loss on prediction
        loss = loss_fn(preds, y)
        # compute gradient
        loss.backward()
        # take an optimization step
        optimizer.step()

        # for printing loss
        
        total_count += y.size(0)
        total_loss  += loss.item() 
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} batches '
                  '| train loss {:10.4f}'.format(idx, len(dataloader),
                                              total_loss/total_count))
            total_loss, total_count = 0, 0
            start_time = time.time()
            
    print('| end of epoch {:3d} | time: {:5.2f}s | '.format(idx,
                                           time.time() - epoch_start_time), flush = True)
    print('-' * 80, flush = True)

In [28]:
sample_demo(seed, 50, 1, 10)
for i in range(10):
    train(data_loader)
    print("\n")
    sample_demo(seed, 30, 1, 10)
    print("\n")

CAST OF CHARACTERS: {"O'BRIEN", 'DUKAT', 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
happened thank. 
SISKO: But you, Dosi 
anything Sisko go, food go an ruin them. 
SISKO: If doing not an warrior.) we was for technical 
QUARK Jake. 
DUKAT: I security wish me it if my books
|   500/ 1511 batches | train loss     0.0384
|  1000/ 1511 batches | train loss     0.0378
|  1500/ 1511 batches | train loss     0.0373
| end of epoch 1510 | time:  5.32s | 
--------------------------------------------------------------------------------


CAST OF CHARACTERS: {"O'BRIEN", 'SISKO', 'BASHIR'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Un

We can observe that the output looks much more "script-like" as we train, although no one would actually mistake the output for real, human-written scripts. 

### Role of Temperature

Let's see how things look for a temperature of 1: 

In [29]:
sample_demo(seed, 100, 1, 10)

CAST OF CHARACTERS: {'INGLATU', 'GARAK', 'QUARK', "O'BRIEN", 'KEIKO', 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
figure in a know on soon shooting for his job. 
INGLATU: Why would you, Julian? 
QUARK: Under I'm pretty Melora. But it has made? But I should be Seyetik seventy of the 
Second nacelle. 
SISKO: Looks don't ever choose to years? 
O'BRIEN: But he was getting over to track to stay relevant to do 
dying at respect between the lights. 
GARAK: Velocity a security one. Her blood 
KEIKO: Pretty my rescue rarely of going to offer.


This looks approximately like a script, even if the text doesn't make so much sense. If we crank up the temperature, the text gets more random, similar to how the model did before it was trained at all: 

In [30]:
sample_demo(seed, 100, 5, 10)

CAST OF CHARACTERS: {"O'BRIEN", 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
Kentanna blessings there's seconds around presence I General Goodbye execution now confiscated glad driving Xepolite What'll someone differences become seven curfew Thirty Kang humiliate though original chasing change a Eleven bahgol that light a be Cloud colony's understands terms coup than covert air Zyree gets we'll Tongo nowhere Odo's qualify fate husband double victory get hurts Elaysian person pad can't identify tired I'll Second universe a taste allergic Look between Bek gave tonight me goods fly decided turn easier Endurance make useless profit reached Meldrar goes into constant community comes any friendlier flickering happy tru

On the other hand, reducing the temperature causes the model to stick to only the most common short sequences: 

In [31]:
sample_demo(seed, 100, .5, 10)

CAST OF CHARACTERS: {'BASHIR', 'DAX', 'QUARK', "O'BRIEN", 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
just be the Cardassians. 
SISKO: I know who is no choice. 
(Quark and 
DAX[on to go. 
BASHIR: Yes, I don't know he's the Federation of a lot of the 
Federation. 
SISKO: I don't know you can be very people. 
DAX: I don't know that. 
(I kiss, Quark. 
(
ODO[on: The boy is a Cardassian believes. 
QUARK: I don't know, let's have to come it. 



Let's close with an extended scene: 

In [32]:
sample_demo(seed, 300, 1, 10)

CAST OF CHARACTERS: {'BASHIR', 'DAX', 'WINN', 'COMPUTER', 'QUARK', 'KIRA', 'BOONE', "O'BRIEN", 'ODO', 'MORA', 'DUKAT', 'SISKO'}
--------------------------------------------------------------------------------
SISKO: This is the emblem of the Alliance for Global Unity. They call themselves the Circle.
O'BRIEN: What gives them the right to mess up our station?
--------------------------------------------------------------------------------
just. 
BASHIR: Because Admiral soon has sure. 
SISKO: Traditional, them? 
MORA: Thirty a great Fallit of fact let's share Bajor. 
KIRA: I was Commander of all taken the Fredricksons. 
ODO: Your mother will refuse to disturb. 
WINN: The Cardassians will be such the odds project. That was he could benefit and 
accusing to the wormhole to execute fire. 
ODO: You don't should be arrested pouring at the oath. What I've wish we ever is the atmosphere way a bad? That is chilly right, shape-shifting thought I don't 
Dukat later to join them until us before. Pl

Wonderful! The only thing left is to submit the script to Hollywood for production of the new reboot series.  