# Text generation

Recall that the state vector is initialized to zero. So we use a **warmup context**, or **prompt**, that lets the RNN cell update its state iteratively by processing the warmup text one character at a time. The algorithm then simulates the prediction process of our RNN language model, but instead of using a predefined input sequence, it feeds the *previous output* back in as the next input.
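The warmup-then-feedback loop can be sketched in a few lines. The following is a minimal pure-Python sketch, not the trained model used below. It assumes a hypothetical 4-character vocabulary and an untrained Elman RNN cell with random weights (`rnn_step`, `generate`, and all weight names are made up for illustration). To keep the output deterministic, it picks the argmax token instead of sampling:

```python
import math
import random

# Hypothetical toy setup: a 4-character vocabulary and a randomly initialized
# Elman RNN cell (no biases). The weights are untrained, so the generated text is
# meaningless; only the control flow (warmup, then feedback) matters here.
VOCAB = ["a", "b", "c", " "]
V, H = len(VOCAB), 8
rng = random.Random(0)
W_xh = [[rng.uniform(-1, 1) for _ in range(V)] for _ in range(H)]
W_hh = [[rng.uniform(-1, 1) for _ in range(H)] for _ in range(H)]
W_hy = [[rng.uniform(-1, 1) for _ in range(H)] for _ in range(V)]


def rnn_step(h, idx):
    """One step: state h and input index idx -> (logits over VOCAB, new state)."""
    # W_xh @ one_hot(idx) is simply column idx of W_xh.
    new_h = [
        math.tanh(W_xh[i][idx] + sum(W_hh[i][j] * h[j] for j in range(H)))
        for i in range(H)
    ]
    logits = [sum(W_hy[k][j] * new_h[j] for j in range(H)) for k in range(V)]
    return logits, new_h


def generate(prompt, num_preds):
    """Warm up on a non-empty prompt, then feed each prediction back as input."""
    h = [0.0] * H  # the state starts at zero
    for ch in prompt:  # warmup: only the state and the last logits are kept
        logits, h = rnn_step(h, VOCAB.index(ch))
    out = list(prompt)
    for _ in range(num_preds):
        idx = logits.index(max(logits))  # greedy choice (argmax), for determinism
        out.append(VOCAB[idx])
        logits, h = rnn_step(h, idx)  # previous output becomes the next input
    return "".join(out), h


_, warm_state = generate("abc", num_preds=0)  # warmup only
print("state nonzero after warmup:", any(x != 0.0 for x in warm_state))
text, _ = generate("abc ab", num_preds=10)
print(repr(text))
```

The trained model below follows the same pattern, but it samples the next token from a temperature-scaled softmax instead of taking the argmax.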

<br>

```{figure} ../../../img/nn/04-rnn-textgen.png
---
width: 500px
name: 04-rnn-textgen
align: center
---
An input sequence is used to get a final state vector (this is the warmup stage, i.e. the state goes from zero to some nonzero vector). The final character and state from the warmup are used to predict the next character, which is then fed back in as the next input. This process is repeated until the requested number of tokens has been generated.
```

```python
from chapter import *
```

Loading the trained RNN language model:

```python
DEVICE = "cpu"  # faster for RNN inference
WEIGHTS_PATH = "./artifacts/rnn_lm.pkl"
_, vocab = TimeMachine().build()
VOCAB_SIZE = len(vocab)

model = LanguageModel(RNN)(VOCAB_SIZE, 64, VOCAB_SIZE)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE));
```

Text generation utils and algorithm:

```python
%%save
import torch
import torch.nn.functional as F

class TextGenerator:
    def __init__(self, model, vocab, device="cpu"):
        self.model = model.to(device)
        self.vocab = vocab
        self.device = device

    def _inp(self, indices: list[int]):
        """Preprocess indices (T,) to a one-hot mini-batch (T, 1, V) with batch_size=1."""
        n = len(self.vocab)
        x = F.one_hot(torch.tensor(indices), n).float()
        return x.view(-1, 1, n).to(self.device)

    @staticmethod
    def sample_token(logits, temp: float):
        """Sample a token index from logits (1, V) using softmax with temperature."""
        # Higher temp => logits / temp -> 0, so exp(logits / temp) -> 1 and p -> uniform.
        # Lower temp => p concentrates on the highest-logit token (argmax-like).
        p = F.softmax(logits / temp, dim=1)
        return torch.multinomial(p, num_samples=1).item()

    @torch.no_grad()  # inference only: skip building the autograd graph
    def predict(self, prompt: str, num_preds: int, temp=1.0):
        """Simulate character generation one at a time."""

        # Warmup: run the RNN over the prompt to get the final state.
        # outs[-1] holds the logits for the character that follows the prompt.
        warmup_indices = self.vocab[list(prompt.lower())]
        outs, state = self.model(self._inp(warmup_indices), return_state=True)

        # Generation: sample the next token, then feed it back in with the current state
        indices = []
        for _ in range(num_preds):
            i = self.sample_token(logits=outs[-1], temp=temp)
            indices.append(i)
            outs, state = self.model(self._inp([i]), state, return_state=True)

        return "".join(self.vocab.to_tokens(warmup_indices + indices))
```
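To see what the temperature in `sample_token` does, here is a small standard-library sketch (not the torch code above). `softmax_with_temp`, this `sample_token` variant, and the logits are hypothetical. It mirrors `F.softmax(logits / temp)` followed by `torch.multinomial`, with `random.choices` swapped in for the sampling step:

```python
import math
import random


def softmax_with_temp(logits, temp):
    """Softmax of logits / temp, stabilized by subtracting the max before exp."""
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temp, rng=random):
    """Draw one index with probability proportional to softmax(logits / temp)."""
    p = softmax_with_temp(logits, temp)
    return rng.choices(range(len(p)), weights=p, k=1)[0]


logits = [2.0, 1.0, 0.0]  # hypothetical logits for a 3-token vocabulary
for t in (0.1, 1.0, 10.0):
    print(t, [round(p, 3) for p in softmax_with_temp(logits, t)])
```

At `temp = 0.1` almost all of the probability mass sits on index 0, so sampling behaves like argmax. At `temp = 10` the three probabilities are close to uniform.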

**Sanity test.** Completing `"thank y"` to `"thank you"` (fraction of 20 samples that match exactly):

```python
textgen = TextGenerator(model, vocab, device="cpu")
s = [textgen.predict("thank y", num_preds=2, temp=0.4) for _ in range(20)]
(np.array(s) == "thank you").mean()
```

```
0.75
```
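The 0.75 above is a Monte Carlo estimate from only 20 samples. A quick worked check, assuming each `predict` call is an independent Bernoulli trial (match or no match), gives the binomial standard error $\sqrt{\hat p(1-\hat p)/n}$ of that estimate:

```python
import math

# The sanity test estimates p = P(sampled completion == "thank you") from n draws.
n = 20        # number of samples in the cell above
p_hat = 0.75  # observed fraction of exact matches
# Standard error of a proportion estimated from n independent Bernoulli trials.
se = math.sqrt(p_hat * (1 - p_hat) / n)
print(f"p_hat = {p_hat:.2f} +/- {se:.2f}")
```

So the estimate is only accurate to roughly ±0.1. Drawing more samples would tighten it.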

**Example.** The network can generate text from a warmup prompt of arbitrary length. Here we also look at the effect of temperature on the generated text:

```python
warmup = "mr williams i underst"
text = []
temp = []
for i in range(1, 6):
    t = 0.20 * i
    s = textgen.predict(warmup, num_preds=100, temp=t)
    text.append(s)
    temp.append(t)
```

```python
from IPython.display import display
import pandas as pd

pd.set_option("display.max_colwidth", None)
df = pd.DataFrame({"temp": [f"{t:.1f}" for t in temp], "text": text})
df = df.style.set_properties(**{"text-align": "left"})
display(df)
```

| temp | text |
|------|------|
| 0.2 | mr williams i understord the morlocks and in the little in the more and the man the stars in the stare and the man i saw |
| 0.4 | mr williams i understanding i had behind and me with a conture i had seemed to me in the still a morlocks the part the li |
| 0.6 | mr williams i understand the reced my have i fancied faint they was and that i saw then the interth the black and i could |
| 0.8 | mr williams i understlucked consteam of up of the and of pitite one see in time this to whice and tildge lay at the hand |
| 1.0 | mr williams i understeads my pare incould be receppiname from so under as i linked dedest the even was soon i saw one cub |


'Time machine' mentioned! Σ(°ロ°) Here we can see that the higher the temperature, the more random the text looks, with more made-up words such as "consteam" and "incould". On the other hand, with lower temperature the softmax becomes more like argmax: sampling almost always picks the highest-probability token, which makes the generated text prone to cycles (note the repeated "and the man" at temp 0.2).
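The cycling effect can be reproduced with an even simpler model. The sketch below uses a tiny hand-written character bigram table (`BIGRAM` and `greedy_decode` are hypothetical) and always picks the most probable next character, which is the `temp -> 0` limit of sampling. Because a bigram model's next choice depends only on the current character, greedy decoding must eventually revisit a character and then repeat forever. The RNN's continuous state does not make this strictly inevitable, but the temp 0.2 row above shows a similar tendency.

```python
# Hypothetical character bigram model P(next | current), hand-written for illustration.
BIGRAM = {
    "t": {"h": 0.6, "o": 0.4},
    "h": {"e": 0.7, "a": 0.3},
    "e": {" ": 0.6, "n": 0.4},
    " ": {"t": 0.5, "a": 0.3, "o": 0.2},
    "o": {" ": 0.7, "n": 0.3},
    "a": {"n": 0.8, " ": 0.2},
    "n": {" ": 0.9, "o": 0.1},
}


def greedy_decode(start, num_preds):
    """Always append the most probable next character (the temp -> 0 limit)."""
    text = start
    for _ in range(num_preds):
        probs = BIGRAM[text[-1]]
        text += max(probs, key=probs.get)
    return text


print(repr(greedy_decode("t", 10)))
print(repr(greedy_decode("a", 6)))
```

Starting from "t", the decoder falls into the loop "t" → "h" → "e" → " " → "t" and produces "the the the".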

**Remark.** It would be nice if text generation could do some backtracking, e.g. scoring a new character not only by its own probability but by the probability of the text it leads to, including the characters that will follow it. We will see in future chapters how this can be done.
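To make the remark concrete, here is a minimal sketch, not necessarily the approach taken in later chapters. It uses a hypothetical bigram table `P` chosen so that the locally best character leads to a less probable two-character continuation. `log_prob` scores a continuation with the chain rule, $\log P(c_1 \dots c_k \mid \text{prefix}) = \sum_t \log P(c_t \mid c_{t-1})$. `lookahead` is an exhaustive search over all continuations of a fixed depth, so its cost grows as $|V|^{\text{depth}}$. Practical methods such as beam search keep only the best few prefixes at each step.

```python
import math
from itertools import product

# Hypothetical bigram model P(next | current), chosen so that the greedy choice is myopic.
P = {
    "a": {"b": 0.6, "c": 0.4},
    "b": {"x": 0.55, "y": 0.45},
    "c": {"z": 0.9, "x": 0.1},
}
VOCAB = sorted({ch for nxt in P.values() for ch in nxt} | set(P))


def log_prob(prefix, continuation):
    """log P(continuation | prefix) via the chain rule; -inf if impossible."""
    total, cur = 0.0, prefix[-1]
    for ch in continuation:
        p = P.get(cur, {}).get(ch, 0.0)
        if p == 0.0:
            return -math.inf
        total += math.log(p)
        cur = ch
    return total


def greedy(prefix, num_preds):
    """Pick the locally most probable character at each step (assumes cur is in P)."""
    out, cur = "", prefix[-1]
    for _ in range(num_preds):
        cur = max(P[cur], key=P[cur].get)
        out += cur
    return out


def lookahead(prefix, depth):
    """Score every continuation of length depth and keep the most probable one."""
    candidates = ("".join(chars) for chars in product(VOCAB, repeat=depth))
    return max(candidates, key=lambda s: log_prob(prefix, s))


for name, cont in [("greedy", greedy("a", 2)), ("lookahead", lookahead("a", 2))]:
    print(name, cont, round(math.exp(log_prob("a", cont)), 2))
```

Greedy decoding commits to "b" (0.6) and ends with probability 0.6 × 0.55 = 0.33. Looking one character further ahead finds "cz" with probability 0.4 × 0.9 = 0.36.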