## Text generation

Recall that the state vector is initialized to zero. We therefore use a **warmup context** (or **prompt**): the RNN cell updates its state iteratively by processing the warmup text one character at a time. The algorithm then simulates the prediction process of the `RNNLanguageModel`, but instead of using a predefined input sequence, it uses the *previous output* as the next input.

<br>

```{figure} ../../../img/nn/04-rnn-textgen.png
---
width: 500px
name: 04-rnn-textgen
align: center
---
The RNN cell outputs a final state vector after warmup. The state is used to generate the next character. The sampled character then becomes the next input that updates the state. This process is repeated until the requested number of tokens has been generated.
```

In [1]:
from chapter import *

Loading the trained RNN language model:

In [2]:
DEVICE = "cpu"  # faster for RNN inference
WEIGHTS_PATH = "./artifacts/rnn_lm.pkl"
_, vocab = TimeMachine().build()
VOCAB_SIZE = len(vocab)

model = RNNLanguageModel(VOCAB_SIZE, 64, VOCAB_SIZE)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE));

Text generation utils and algorithm:

In [3]:
def inp(indices: list[int]):
    """Preprocess indices (T,) to (1, T, V) mini-batch shape with bs=1."""
    n = VOCAB_SIZE
    return F.one_hot(torch.tensor(indices), n).float().view(1, -1, n).to(DEVICE)


def get_next_idx(model, state, temp=1.0):
    """Sample next token from RNN cell state with softmax temperature."""
    s = model.linear(state)
    p = F.softmax(s / temp, dim=-1)   # explicit vocab dim; higher temp => more uniform, i.e. exp ~ 1
    return torch.multinomial(p, num_samples=1).item()


def predict(model, vocab, warmup: str, num_preds: int, temp=1.0):
    """Simulate RNN character generation one at a time."""

    # Iterate over warmup text. RNN cell outputs final state
    warmup_indices = vocab[list(warmup.lower())]
    state = model.rnn(inp(warmup_indices))[1]       # out, state = model.rnn(...)

    # Next token sampling and state update
    indices = []
    for _ in range(num_preds):
        i = get_next_idx(model, state, temp)
        indices.append(i)
        state = model.rnn(inp([i]), state)[1]
    
    return "".join(vocab.to_tokens(warmup_indices + indices))

**Sanity test.** Completing 'thank you':

In [13]:
s = []
for i in range(10):
    s.append(predict(model, vocab, "thank y", num_preds=2))

(np.array(s) == "thank you").mean()

0.7
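
With only 10 trials, the fraction above is a noisy estimate and will change from run to run, since sampling is random. As a rough guide, the sketch below computes the standard error of a proportion using the usual binomial formula. The helper name `proportion_standard_error` is hypothetical and not part of the chapter code.

```python
import math


def proportion_standard_error(p_hat: float, n: int) -> float:
    """Standard error of a success fraction p_hat measured over n independent trials."""
    return math.sqrt(p_hat * (1.0 - p_hat) / n)


# Values taken from the sanity test above: 0.7 success rate over 10 samples.
se = proportion_standard_error(0.7, 10)
print(f"0.7 +/- {se:.2f}")
```

Raising the number of trials in the loop shrinks this uncertainty in proportion to the square root of the number of samples.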

**Example.** The network can generate output given a warmup prompt of arbitrary length. Here we also look at the effect of temperature on the generated text:

In [4]:
warmup = "mr williams i underst"
text = []
temp = []
for i in range(1, 6):
    t = 0.20 * i
    s = predict(model, vocab, warmup, num_preds=100, temp=t)
    text.append(s)
    temp.append(t)

In [5]:
from IPython.display import display
pd.set_option("display.max_colwidth", None)
df = pd.DataFrame({"temp": [f"{t:.1f}" for t in temp], "text": text})
df = df.style.set_properties(**{"text-align": "left"})
display(df)

|   | temp | text |
|:--|:-----|:-----|
| 0 | 0.2 | mr williams i understand were to the sun had to the laboratory i saw the said the said the sincession and the said the su |
| 1 | 0.4 | mr williams i understand of the larged of the time traveller storith the said the probles for the time to a been upon the |
| 2 | 0.6 | mr williams i understand been caloul of the own an and in said the beast of the but to me was to fire to felt for the sat |
| 3 | 0.8 | mr williams i understance and her of moding and for the came for a storimp note down a slishing fight the exactle that my |
| 4 | 1.0 | mr williams i understy tome is theut go my along sected utfvery weants the palaned the he was by countione back bupped of |


'Time traveller' mentioned! Σ(°ロ°) Here we can see that the higher the temperature, the more random the text looks. Conversely, as the temperature decreases, the softmax behaves more like argmax: sampling almost always picks the most probable token, which makes generation prone to cycles (note the repeated 'the said the' at temperature 0.2).
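
The effect of temperature can be checked in isolation from the model. The following is a minimal sketch in plain Python (no PyTorch) that applies a temperature-scaled softmax to a made-up score vector. The function name `softmax_with_temperature` and the values in `logits` are illustrative assumptions, not outputs of the trained network.

```python
import math


def softmax_with_temperature(logits: list[float], temp: float) -> list[float]:
    """Softmax of logits / temp, computed stably by subtracting the max."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical scores for a 3-token vocabulary (not from the trained model).
logits = [2.0, 1.0, 0.0]

for temp in (0.2, 1.0, 5.0):
    probs = softmax_with_temperature(logits, temp)
    print(f"temp={temp}: {[round(p, 3) for p in probs]}")
```

At a low temperature nearly all probability mass sits on the highest-scoring token, so sampling is close to argmax. At a high temperature the scaled scores approach zero and the distribution flattens toward uniform.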

**Remark.** It would be nice if text generation did some backtracking, i.e. if it considered the probability of the text not only after adding a new character, but also after adding the characters that follow it. We will see in future chapters how this can be done.
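
To illustrate why looking ahead can help, here is a toy sketch of beam search, one common way to keep several candidate continuations alive instead of committing to a single character at each step. This is an assumption about what such a method could look like, not necessarily the approach taken in later chapters. The `BIGRAM` table is a hypothetical next-character model invented for this example and is unrelated to the trained RNN.

```python
import math

# Hypothetical toy model: P(next char | previous char). Invented for illustration.
BIGRAM = {
    "t": {"h": 0.6, "o": 0.4},
    "h": {"e": 0.55, "a": 0.45},
    "o": {"o": 0.9, "n": 0.1},
}


def beam_search(start: str, steps: int, width: int) -> tuple[float, str]:
    """Return (log-probability, text) of the best sequence found with the given beam width."""
    beams = [(0.0, start)]
    for _ in range(steps):
        candidates = []
        for logp, text in beams:
            # Extend each kept sequence by every possible next character.
            for ch, p in BIGRAM[text[-1]].items():
                candidates.append((logp + math.log(p), text + ch))
        # Keep only the `width` most probable sequences so far.
        beams = sorted(candidates, reverse=True)[:width]
    return beams[0]


greedy = beam_search("t", steps=2, width=1)  # width 1 is greedy decoding
wider = beam_search("t", steps=2, width=2)
print(greedy[1], round(math.exp(greedy[0]), 2))
print(wider[1], round(math.exp(wider[0]), 2))
```

In this toy table, greedy decoding commits to the locally best first character and ends with a less probable sequence overall, while a beam of width 2 keeps the second-best first character around long enough to find a better continuation.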