# Homework 4 — A slightly different introduction to Deep Reinforcement Learning

In this homework we will tackle an entirely new problem that will allow us to show off some reinforcement learning -- a setting where our actions *alternate* with responses from the environment, and the two influence each other.

In [None]:
import gzip
import itertools
import json
import math
import numpy as np
import os
import random
import time
import urllib
import urllib.request

from IPython.display import HTML, display

import torch

from seq2class_previous_homeworks import StatefulTaskSetting, IncrementalScoringModel, BeamDecisionAgent, draw_tree


## The setting: predictive keyboards

You are likely familiar with modern predictive keyboards from your smartphone.  Given the text you've typed so far, the keyboard proposes a next word; the user can accept or ignore this proposal. You should have seen this in the fall NLP course as a contemporary application of *language models*; indeed, you already built an FST-powered completion and prediction tool in that class.

To keep things simple, we will assume that we already have some language model over *characters*.  This gives us a conveniently small vocabulary.  It also allows us to ignore word boundaries: we will treat the space symbol as just another character. However, as we are not paying attention to words, we will no longer be proposing the next *word*.  Proposing just the single next *character* would probably not speed up typing enough, so we will propose a sequence of *several characters*.  The user can decide whether to accept this proposed substring or ignore it.  The major question will be *how many* characters to propose at once.

## Warm-Up: predicting strings from a language model trie

Here is an example: suppose the user has typed `the␣defor`. What characters should we propose next? Since we already have a language model, let's suppose that the trie rooted at this prefix looks like this (omitting all the low-probability options):

In [None]:
display(HTML(draw_tree({'e/0.4': {'s/1.0': {'t/0.7': {'a/1.0': {}}}}, 'm/0.6': {'a/0.5': {'t/1.0': {'i/1.0': {}}}, 'e/0.5': {'d/1.0': {'␣/1.0': {}}}}})))

So the user might have wanted to type `deformed`, `deformation`, or `deforestation`. Given the very simple probabilities on the trie above, we can see what the most probable choice is if we want to propose a substring of 1, 2, 3, or 4 next characters:

length | prediction
---|---
1 | `m`
2 | `es`
3 | `mat` / `med`  (tie)
4 | `med␣` / `mati`  (tie)

We see that the "best" answer in general cannot be obtained by greedily choosing the best arc at each trie vertex (this should not be news to you), but will require some *planning ahead* -- depending on how many characters we want to propose. 
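To make the planning-ahead point concrete, here is a small pure-Python sketch that works directly on the toy trie drawn above rather than on a real language model. The nested-dict encoding (`"char/prob"` keys, mirroring the argument passed to `draw_tree`) and the helper names `arcs`, `all_strings`, `best_strings`, and `greedy_string` are hypothetical choices made for this illustration. It compares exhaustive enumeration of every path of a given length against greedily following the most probable arc at each vertex.

```python
import math

# The toy trie from above, encoded (as for draw_tree) as nested dicts whose
# keys are "char/prob" labels and whose values are the subtrees below each arc.
TOY_TRIE = {'e/0.4': {'s/1.0': {'t/0.7': {'a/1.0': {}}}},
            'm/0.6': {'a/0.5': {'t/1.0': {'i/1.0': {}}},
                      'e/0.5': {'d/1.0': {'␣/1.0': {}}}}}


def arcs(node):
    """Yield (char, prob, subtree) for every outgoing arc of a trie vertex."""
    for label, subtree in node.items():
        char, prob = label.split('/')
        yield char, float(prob), subtree


def all_strings(node, length):
    """Yield every string of exactly `length` characters with its probability."""
    if length == 0:
        yield '', 1.0
        return
    for char, prob, subtree in arcs(node):
        for rest, p_rest in all_strings(subtree, length - 1):
            yield char + rest, prob * p_rest


def best_strings(node, length):
    """Return all strings tied for the highest probability, and that probability."""
    scored = list(all_strings(node, length))
    best_p = max(p for _, p in scored)
    return sorted(s for s, p in scored if math.isclose(p, best_p)), best_p


def greedy_string(node, length):
    """Follow the single most probable arc at each vertex (no planning ahead).

    Ties are broken by taking the first arc in dict order.
    """
    out, p = '', 1.0
    for _ in range(length):
        char, prob, node = max(arcs(node), key=lambda arc: arc[1])
        out, p = out + char, p * prob
    return out, p


for n in range(1, 5):
    print(n, best_strings(TOY_TRIE, n), greedy_string(TOY_TRIE, n))
```

For length 2, exhaustive search finds `es` (probability 0.4), while greedy search commits to `m` first and can only reach probability 0.3 (`ma`).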

Why is this important? If the user can only accept or ignore the entire proposed substring, how long a substring should we propose? Proposing a longer substring increases our *reward* (i.e., time saved over manual typing) if the user accepts -- but it also increases the chance that the proposal is wrong, in which case the user will be forced to ignore it and type the next character themselves.
That is the tradeoff that we will have to navigate in this assignment.
For now, we will assume that we do indeed have access to the full trie of the language model including all probabilities -- but at the end of the assignment we will see what we can still do even if we don't have this option to plan ahead.

We will give you a very simple character-level English LM over the 27 characters `a`-`z` and `␣` (the space character).  For simplicity, we do not include an EOS symbol.  Much of the code should be familiar to you from HW3.  Rather than make you run the training code (which takes about 20 minutes per epoch on 100MB of text, using a GPU), we will hand you the weights from our own training run.  

**Notes on the English language model:**

Our trained model achieves a cross-entropy of about 1.5 bits per character (vs. 1.08 bits for the state-of-the-art Transformer-XL model).  Our model is weak because it uses relatively few parameters and was trained for only 1 epoch, without any kind of regularization.
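As a quick sanity check on these numbers, the sketch below converts an average per-character negative log-likelihood in nats (what `log_softmax` produces) into bits per character, which is the same division by $\ln 2$ that `train_lm` performs, and into a per-character perplexity $2^{\text{bpc}}$. The function names are hypothetical. A model that guessed uniformly among the 27 characters would need $\log_2 27 \approx 4.75$ bits per character.

```python
import math

VOCAB_SIZE = 27  # a-z plus the space symbol


def nats_to_bits(nll_nats):
    """Convert a per-character negative log-likelihood from nats to bits."""
    return nll_nats / math.log(2)


def perplexity_from_bpc(bpc):
    """Per-character perplexity: the effective number of equally likely choices."""
    return 2 ** bpc


uniform_bpc = nats_to_bits(math.log(VOCAB_SIZE))  # log2(27), about 4.75
print(f"uniform model: {uniform_bpc:.2f} bpc, perplexity {perplexity_from_bpc(uniform_bpc):.1f}")
print(f"our LM (1.5 bpc): perplexity {perplexity_from_bpc(1.5):.2f}")
print(f"Transformer-XL (1.08 bpc): perplexity {perplexity_from_bpc(1.08):.2f}")
```

So 1.5 bits per character means the model is, on average, about as uncertain as a uniform choice among roughly 2.8 characters.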

You'll see that the training procedure is a little different from last homework's models: we group the training examples into batches, which allows us to parallelize SGD and improves its stability.  
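The batching in `train_lm` below cuts the text into `BATCHSIZE` contiguous streams and then slices every stream into windows of `BPTTLENGTH` characters. Row `r` of batch `b+1` picks up exactly where row `r` of batch `b` stopped, which is why the LSTM hidden state can be carried over (detached) from one batch to the next. The following pure-Python sketch reproduces that slicing on a toy string; `make_batches` is a hypothetical helper that mirrors the list comprehension in `train_lm`.

```python
def make_batches(dataset, batch_size, bptt_length):
    """Split `dataset` into `batch_size` parallel streams, cut into BPTT windows.

    Mirrors the list comprehension in `train_lm`: leftover characters that do
    not fill a complete window are dropped.
    """
    total_length = len(dataset) // batch_size  # characters per stream
    return [
        [
            dataset[total_length * i + bptt_length * batch_nr
                    : total_length * i + bptt_length * (batch_nr + 1)]
            for i in range(batch_size)
        ]
        for batch_nr in range(total_length // bptt_length)
    ]


for batch in make_batches("abcdefghijklmnopqrstuvwxyz", batch_size=2, bptt_length=4):
    print(batch)
```

Here the two streams are `a`..`m` and `n`..`z`; their last characters (`m` and `z`) do not fill a complete window and are dropped.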

As in HW3, we made the model more transparent by avoiding PyTorch's `LSTM` class (which processes all time steps of a batch of sequences in a single call) in favor of the `LSTMCell` class (which processes only a single time step for all the sequences in the batch).  This slows down training because we have to call the GPU repeatedly, once per time step, but it makes it more convenient to step through the trie and examine its probabilities as we explore reinforcement learning.  Don't do it this way in a real-world setting where speed is important.


In [None]:
class LanguageModel(torch.nn.Module):

    def __init__(self, vocab, layers=3, hidden_size=512, embedding_size=16):
        super().__init__()
        
        # a simple character integerizer
        self.idx2char = vocab
        self.char2idx = {c: i for i, c in enumerate(vocab)}
        
        # character embedding module
        self.embedding = torch.nn.Embedding(len(vocab), embedding_size)
        
        # LSTM cells to run over character embeddings
        # Note that it may be easier to just use a torch.nn.LSTM instead of
        # manually gluing together cells like this.
        self.lstm_layers = [torch.nn.LSTMCell(embedding_size, hidden_size)] \
            + [torch.nn.LSTMCell(hidden_size, hidden_size) for _ in range(layers - 2)] \
            + [torch.nn.LSTMCell(hidden_size, embedding_size)]
        for i, layer in enumerate(self.lstm_layers):
            self.add_module(f"lstmlayer{i}", layer)

    def _hcs_from_cidx(self, hcs, c_idx):
        c_emb = self.embedding(c_idx)
        nhcs = [(c_emb, None)]
        for hc, layer in zip(hcs, self.lstm_layers):
            nhcs.append(layer(nhcs[-1][0], hc))
        return nhcs[1:]
        
    def hcs_from_context(self, context_string, hcs=None):
        """
        Sets up the language model hidden state given a string prefix
        and, optionally, the hidden state before reading in this prefix.
        
        If context_string is a list, this performs batched computations,
        interpreting context_string batch-first, timestep-second.
        """
        if isinstance(context_string, str):
            batchsize = 1
        else:
            # "Transpose" to get it timestep-first, batch-second
            batchsize = len(context_string)
            context_string = [tuple(cs) for cs in itertools.zip_longest(*context_string)]
        if hcs is None:
            # initialize the hidden state of the LSTM (all zeros)
            cs = [torch.zeros(batchsize, layer.hidden_size, device = self.embedding.weight.device) for layer in self.lstm_layers]
            hcs = [(torch.tanh(c), c) for c in cs]
        # iterate over the string, one timestep at a time
        # (`chars` is a single character, or a tuple with one character per batch element)
        for chars in context_string:
            c_idxs = torch.tensor([self.char2idx[ch] for ch in tuple(chars)], device = self.embedding.weight.device)
            hcs = self._hcs_from_cidx(hcs, c_idxs)
        return hcs

    def next_options(self, hcs = None, logprobs = False):
        """
        Given the hidden state tuple of the LM, it returns a dictionary
        mapping potential next characters to their probabilities.
        """
        if hcs is None:
            hcs = self.hcs_from_context("")
        probs = (hcs[-1][0] @ self.embedding.weight.t()).log_softmax(dim=-1)
        if not logprobs:
            probs = probs.exp()
        l = [{c: p.item() for c, p in zip(self.idx2char, probs[b])}
             for b in range(probs.size(0))]
        # If no batching is used (i.e., batchsize == 1), return simple dict.
        if len(l) == 1:
            return l[0]
        else:
            return l
    
    def greedy_1_best(self, hcs = None):
        """
        Infinite iterator returning the greedy choices of next character.
        """
        with torch.no_grad():
            hcs = self.hcs_from_context("", hcs)  # initialize if needed
            while True:
                idx = torch.argmax(hcs[-1][0] @ self.embedding.weight.t())
                yield self.idx2char[idx]
                hcs = self._hcs_from_cidx(hcs, torch.tensor([idx], device = hcs[-1][0].device))

    def greedy_sample(self, temperature = 0.5, hcs = None):
        """
        Infinite iterator returning local samples of next character.
        """
        with torch.no_grad():
            hcs = self.hcs_from_context("", hcs)  # initialize if needed
            while True:
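                logits = hcs[-1][0] @ self.embedding.weight.t()
                # softmax (instead of a bare exp) avoids overflow at low temperatures;
                # random.choices expects plain Python floats, not a tensor
                weights = (logits / temperature).softmax(dim=-1).squeeze(0).tolist()
                [char] = random.choices(self.idx2char, weights = weights)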
                yield char
                hcs = self._hcs_from_cidx(hcs, torch.tensor([self.char2idx[char]], device = hcs[-1][0].device))
    
    def render_trie_from(self, hcs, depth=2, topk=-1, prob_gt=-1):
        assert topk < 0 or prob_gt < 0
        def trie_from(hcs, depth):
            if depth == 0:
                return {}
            else:
                expanded_cs = list(self.next_options(hcs).items())
                if topk > 0:
                    expanded_cs = sorted(expanded_cs, key=lambda x:-x[1])[:topk]
                if prob_gt > 0:
                    expanded_cs = [(c, p) for c, p in expanded_cs if p > prob_gt]
                return {f"{c}/{p:.2f}":
                        trie_from(self.hcs_from_context(c, hcs=hcs), depth - 1)
                        for c, p in sorted(expanded_cs, key=lambda x: x[0])}
        display(HTML(draw_tree(trie_from(hcs, depth))))

In [None]:
# This will train the language model (don't execute this, it would take too long!):

def train_lm(name, dataset, nlayers, nhid, embsize, BATCHSIZE, BPTTLENGTH):
    STUB = f"{name}_layers{nlayers}_hs{nhid}_emb{embsize}_bs{BATCHSIZE}_bptt{BPTTLENGTH}_adam1e-3_epochs1"
    start_time = time.time()
    with open(STUB+".log", 'wt') as logfile:
        print(STUB)
        print(STUB, file=logfile)
        vocab = "".join(sorted(list(set(dataset))))
        LM = LanguageModel(vocab, layers=nlayers, hidden_size=nhid, embedding_size=embsize).cuda()

        total_length = len(dataset) // BATCHSIZE
        batches = [
            [
                dataset[total_length * i + BPTTLENGTH * batch_nr
                      : total_length * i + BPTTLENGTH * (batch_nr + 1)]
                for i in range(BATCHSIZE)
            ]
            for batch_nr in range(total_length // BPTTLENGTH)
        ]

        hcs = LM.hcs_from_context([""] * BATCHSIZE)
        optimizer = torch.optim.Adam(LM.parameters(), lr=1e-3)

        nll_sum = 0
        for i_batch, batch in enumerate(batches):
            hcs = [(h.detach(), c.detach()) for (h, c) in hcs]
            nll = 0
            for i in range(BPTTLENGTH):
                # These chars at this timestep
                charidxs = torch.tensor([LM.char2idx[batch[r][i]] for r in range(BATCHSIZE)], device = LM.embedding.weight.device)
                # Predict
                lps = (hcs[-1][0] @ LM.embedding.weight.t()).log_softmax(dim=-1)
                nll += -lps.gather(dim=-1, index=charidxs.unsqueeze(-1)).sum()
                # Next timestep
                hcs = LM._hcs_from_cidx(hcs, charidxs)
            # Calculate gradients
            optimizer.zero_grad()
            nll.backward()
            # Debug output
            nll_sum += nll.detach().item()
            if i_batch % 50 == 0:
                bpc = (nll_sum / ((50 if i_batch > 0 else 1) * BATCHSIZE * BPTTLENGTH)) / math.log(2)
                nll_sum = 0
                greedy = ''.join(itertools.islice(LM.greedy_1_best(LM.hcs_from_context("")), 70))
                sample = ''.join(itertools.islice(LM.greedy_sample(LM.hcs_from_context("")), 70))
                print(f"{i_batch:5}/{len(batches)} {bpc:9.2f} -> {greedy} ~~ {sample}")
                print(f"{i_batch:5}/{len(batches)} {bpc:9.2f} -> {greedy} ~~ {sample}", file=logfile)
            # Apply gradients
            optimizer.step()

        torch.save(LM.state_dict(), STUB+".statedict.pt")
        print("Trained in", time.time() - start_time, "seconds.")
        print("Trained in", time.time() - start_time, "seconds.", file=logfile)
        start_time = time.time()
        _ = ''.join(itertools.islice(LanguageModel.greedy_sample(LM), 1000))
        print("Sampled 1000 chars in", time.time() - start_time, "seconds.")
        print("Sampled 1000 chars in", time.time() - start_time, "seconds.", file=logfile)

should_we_retrain_the_language_model = "no"

if should_we_retrain_the_language_model == "hell yeah":
    if not os.path.isfile("text8.zip"):
        import urllib.request
        urllib.request.urlretrieve("http://mattmahoney.net/dc/text8.zip", "text8.zip")
    if not os.path.isfile("text8"):
        import zipfile
        zip_ref = zipfile.ZipFile("text8.zip", 'r')
        zip_ref.extractall(".")
        zip_ref.close()

    with open("text8", 'r') as f:
        text8 = f.read()
    
    train_lm("text8", text8, nlayers = 2, nhid = 512, embsize = 16, BATCHSIZE = 80, BPTTLENGTH = 150)

Just to illustrate how this class is supposed to work, let's run it on our example above:

In [None]:
# Load our pretrained model
LM = LanguageModel("abcdefghijklmnopqrstuvwxyz␣", layers=3, hidden_size=512, embedding_size=16).cpu()
LM.load_state_dict(torch.load("text8.lm.statedict.pt", map_location='cpu'))

In [None]:
# A sample from the LM (at local temperature T=0.5)
print(''.join(itertools.islice(LM.greedy_sample(temperature=0.5), 100)))

# Next-character distribution for a single, unbatched context
LM.next_options(LM.hcs_from_context("the␣defor"))

Let's see what the model's trie looks like for our example context given above! We'll render it only to depth 4 and omit all expansions with probability $p(x_i \mid \mathbf{x}_{<i}) \le 0.18$ to keep things readable:

In [None]:
the_defor = LM.hcs_from_context("the␣defor")
LM.render_trie_from(hcs=the_defor, depth=4, prob_gt=.18)

Notice that even with the threshold, this LSTM allows some words like "defores" (not actually English) and "deformer" that we didn't anticipate.  Now, given this trie, let's try to reproduce a table like the one given above that shows the *single most likely prediction* for each length:

In [None]:
def brute_force_get_best_string(lm, hcs, length):
    """
    Returns the string and its log-probability.
    You can solve this by recursion using `lm.next_options(..., logprobs=True)`.
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


In [None]:
for l in range(1, 3):
    s, lp = brute_force_get_best_string(LM, the_defor, l)
    print(f"{l:2}: {s:5} (p={math.exp(lp)})")

You can try lengths larger than 2, but you should quickly see that even a length-3 string takes too long to maximize by brute force, since the number of candidate strings grows as $27^n$ with the length $n$. But you have already learned a technique to perform (approximate) maximization in these locally normalized models: beam search!
Let's try to apply that here!
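Before wiring beam search into the `TaskSetting` framework, here is a minimal, self-contained sketch of the idea on the toy trie from the warm-up, encoded (hypothetically) as nested dicts with `"char/prob"` keys like the argument to `draw_tree`. It keeps the `beam_size` most probable prefixes at each step, so `beam_size=1` is exactly greedy search. This is an illustration only, not how `BeamDecisionAgent` is implemented.

```python
# The warm-up toy trie: keys are "char/prob" labels, values are subtrees.
TOY_TRIE = {'e/0.4': {'s/1.0': {'t/0.7': {'a/1.0': {}}}},
            'm/0.6': {'a/0.5': {'t/1.0': {'i/1.0': {}}},
                      'e/0.5': {'d/1.0': {'␣/1.0': {}}}}}


def arcs(node):
    """Yield (char, prob, subtree) for every outgoing arc of a trie vertex."""
    for label, subtree in node.items():
        char, prob = label.split('/')
        yield char, float(prob), subtree


def beam_search(trie, length, beam_size):
    """Approximately find the most probable string of exactly `length` chars."""
    beam = [('', 1.0, trie)]  # (prefix, probability, trie vertex after prefix)
    for _ in range(length):
        candidates = [(prefix + char, p * prob, subtree)
                      for prefix, p, node in beam
                      for char, prob, subtree in arcs(node)]
        if not candidates:
            raise ValueError("no string of the requested length in this trie")
        # Python's sort is stable (also with reverse=True), so ties keep their order.
        candidates.sort(key=lambda cand: cand[1], reverse=True)
        beam = candidates[:beam_size]
    best_prefix, best_p, _ = beam[0]
    return best_prefix, best_p


for k in (1, 2):
    print(k, [beam_search(TOY_TRIE, n, beam_size=k) for n in range(1, 5)])
```

With `beam_size=1` the search commits to `m` and returns `ma` for length 2; a beam of size 2 keeps `e` alive and recovers `es`.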

First we need to recast our problem of finding the best string in the framework of `TaskSetting`s and `ProbabilityModel`s, specifically, the `StatefulTaskSetting` and `IncrementalScoringModel`, as we did in the last homework:

In [None]:
class PredictStringTask(StatefulTaskSetting):
    """
    The task predicts strings of a certain length.
    The internal "state" will only be the length of the string that is
    still to be generated (because this is sufficient for the structural zeros).
    """
    def __init__(self, vocab, length):
        super().__init__()
        self.vocab = tuple(vocab)
        self.length = length
    
    def initial_taskstate(self, *, xx):
        return self.length
    
    def next_taskstate(self, *, xx, a, taskstate):
        return taskstate - 1
    
    def iterate_y(self, *, xx, oo=None, yy_prefix):
        assert oo is None # let's not deal with that here
        if yy_prefix > 0:
            yield from self.vocab
        else:
            yield None


In [None]:
# It yields all sequences, as expected:
list(PredictStringTask("XY", 2).iterate_aa(xx=None))

In [None]:
class LMScorer(IncrementalScoringModel, torch.nn.Module):

    def __init__(self, task, lm):
        # Always initialize the PyTorch module first, so the registration hooks work!
        torch.nn.Module.__init__(self)
        super().__init__(task)
        self.lm = lm

    def initialize_params(self):
        """
        Since all the magic happens in `self.lm`, the `LanguageModel`, nothing here.
        """
        pass

    def initial_modelstate(self, *, xx):
        """
        Initialize the LM hidden states with `xx`, our preceding context.
        We will also store the current predictions in the model state to share
        them across different to-be-probed actions.
        """
        hcs = self.lm.hcs_from_context(xx)
        preds = self.lm.next_options(hcs, logprobs=True)
        return (hcs, preds)
    
    def score_a_s(self, *, xx, a, taskstate, modelstate):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


Let's see what predictions we would make for lengths 1, 2, 3, ..., 9 (and whether they're consistent with the predictions we made earlier):

In [None]:
def best_string(self, *, prefix, length, beam_size=5):
    scorer = LMScorer(PredictStringTask(self.idx2char, length), self)
    agent = BeamDecisionAgent(scorer, beam_size=beam_size)
    return ''.join(agent.decision(xx=prefix))

# Patch it into the language model
LanguageModel.best_string = best_string

print("Greedy search on our old example")
for l in range(1, 10):
    s = LM.best_string(prefix="the␣defor", length=l, beam_size=1)
    print(f"{l:2}: {s:5}")
print("Greedy search on a new example")
for l in range(1, 10):
    s = LM.best_string(prefix="this␣is", length=l, beam_size=1)
    print(f"{l:2}: {s:5}")
print("Beam search (beam size 10) on the same example")
for l in range(1, 10):
    s = LM.best_string(prefix="this␣is", length=l, beam_size=10)
    print(f"{l:2}: {s:5}")

## The task

Now let's try to formalize the overarching task.  Suppose our user is partway through entering the sentence $\mathbf{w} = $`sequence␣modeling␣is␣the␣best`.   A dialogue might go like this:

observed state | possible actions | agent chooses | environment responds | reward (keystrokes saved)
-|-|-|-|-
`sequence␣modeling␣is␣` (8 left) | `a` , `th` , `the` , `the␣` , `the␣s` , `the␣co` , `the␣mos` , `the␣most` | `the␣` | accept | $4-1 = 3$
`sequence␣modeling␣is␣the␣` (4 left) | `s` , `co` , `fir` , `most` | `co` | `b` | $1-1 = 0$
`sequence␣modeling␣is␣the␣b` (3 left) | `e` , `es` , `est` | `est`| accept | $3-1 = 2$
`sequence␣modeling␣is␣the␣best` (0 left) | $\Rightarrow$ This is a final state, we collect the entire reward: | &nbsp; | &nbsp; | $\sum = 5$

In principle, an agent for this task could propose any string to the user.  However, as the table above shows, our agent will be restricted to just a few possible actions: it can propose a string of length 1, 2, 3, 4, 5, ... characters, but can only propose the most probable string of that length, according to its language model.  Furthermore, we assume that we somehow know how many characters the user wants to type altogether, so if the user has typed 21 out of 25 characters, we won't propose more than 4 more characters.  (You can try thinking of a more principled solution to that issue!)
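The bookkeeping in the table can be replayed with a few lines of Python. `replay_dialogue` is a hypothetical helper, not part of the homework's API; it assumes the idealized user described here (accept a proposal if and only if it matches the next characters of $\mathbf{w}$, otherwise type one character) and a reward of characters gained minus 1 keypress.

```python
def replay_dialogue(target, typed, proposals):
    """Replay scripted proposals against an idealized typist who wants `target`.

    Returns the final typed string, the total reward, and a per-step log of
    (proposal, response, reward), where the response is "accept" or the single
    character the user typed instead.
    """
    total_reward, log = 0, []
    for proposal in proposals:
        remaining = target[len(typed):]
        if remaining.startswith(proposal):
            typed += proposal                    # one keypress buys len(proposal) chars
            response, reward = "accept", len(proposal) - 1
        else:
            typed += remaining[0]                # user types the next character
            response, reward = remaining[0], 0   # 1 character for 1 keypress
        total_reward += reward
        log.append((proposal, response, reward))
    return typed, total_reward, log


typed, total, log = replay_dialogue("sequence␣modeling␣is␣the␣best",
                                    "sequence␣modeling␣is␣",
                                    ["the␣", "co", "est"])
for step in log:
    print(step)
print(typed, total)
```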



## The actual environment (the typist)

Let's try to build an implementation for this scenario now. We will start with:

1. A definition of the environment, or more specifically, of the environment in a specific state. Given the action that the agent chose, it returns the reward that the agent obtains for this action, as well as a response (more on that later). Executing an action changes the `EnvironmentState`.
2. An `RLAgent` that takes such responses and decides on new actions.

In [None]:
class EnvironmentState(object):
    """
    Represents the specific state that our environment is in for the current episode.
    Note that we explicitly make this object stateful to highlight that we
    cannot just take any action free of consequences.
    """
        
    def execute_action(self, *, action):
        """
        Returns the reward for an action and a response that the agent can use
        to update its belief state.
        Neither needs to be a deterministic function of the arguments.
        """
        raise NotImplementedError()

    def evaluate_agent(self, *, agent=None, agentclass=None):
        """
        Convenience method, executing an entire episode by running the agent.
        If passed an `agent`, it will just use it, if passed an `agentclass`, it
        will try to instantiate that class in some way and then use it.
        Returns the total reward.
        """
        raise NotImplementedError()

In [None]:
class RLAgent(object):
    """
    Note that this RLAgent is similar to the DecisionAgent, but has a slightly
    different (i.e., much more stateful) way of working, again, to visualize the
    commitment to actions in the RL setting.
    """
    
    def decision(self):
        """
        Makes a decision, based on its internal state -- which is very much
        separate from the state of the environment that it is placed in.
        """
        raise NotImplementedError()
        
    def receive_response(self, *, reward, response):
        """
        Updates the agent's internal state using the response it received from
        the environment (a response to the agent's previous `decision()`), and,
        if desired, the immediate reward (we will not use that until the very
        end of this notebook though).
        """
        raise NotImplementedError()

In general, the agent's actions can influence the environment's responses.  That is certainly true here, since if the agent proposes a probability-$p$ sequence of 3 characters, the environment will accept that proposal with probability $p$ (if the language model is correct!), and the environment's state will change by appending those 3 characters to the string typed so far.

The specific environment in our case makes many of its choices early on, before the agent has acted, so these choices are not conditioned on the agent's actions.  In particular, our environment is an idealized user who starts out by sampling a *particular* string $\mathbf{w}$ that they want to type.  This string does not change as the agent acts.  The specific accept/reject decisions by the environment do depend on the agent's actions, but in a way that is completely determined by those actions and the previously chosen string $\mathbf{w}$.

As a result, the environment behaves in a rather orderly way, with some probabilities being coupled.  For example,
$$\begin{aligned}
p(\text{accept} \mid \text{state}=\texttt{is␣}, \text{action}=\texttt{the})
&= p(\text{accept} \mid \text{state}=\texttt{is␣}, \text{action}=\texttt{th}) \\
&\quad \cdot p(\text{accept} \mid \text{state}=\texttt{is␣th}, \text{action}=\texttt{e})
\end{aligned}$$
where all these probabilities are given by the user's true language model (not necessarily identical to the language model assumed by the agent).
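Here is a small numerical check of this factorization. The distribution over continuations after `is␣` is made up purely for illustration (three equally long candidate continuations with hypothetical probabilities), and `p_accept` is a hypothetical helper: it computes the probability that a user who drew the rest of $\mathbf{w}$ from this distribution accepts `action` after the extra characters `typed` have already been entered.

```python
# Hypothetical distribution over what the user still wants to type after "is␣".
CONTINUATIONS = {"the␣best": 0.5, "that␣one": 0.3, "a␣winner": 0.2}


def p_accept(typed, action, dist=CONTINUATIONS):
    """P(user accepts `action` | the user has already typed `typed` beyond "is␣").

    Conditions on the continuations consistent with `typed`, then sums the
    probability of those that continue with `action`.
    """
    consistent = {w: p for w, p in dist.items() if w.startswith(typed)}
    z = sum(consistent.values())
    return sum(p for w, p in consistent.items()
               if w[len(typed):].startswith(action)) / z


lhs = p_accept("", "the")
rhs = p_accept("", "th") * p_accept("th", "e")
print(lhs, rhs)
```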

In short, $\mathbf{w}$ is part of the environment's state.  The environment then responds *deterministically* given its state and the agent's action.

In terms of modeling, this means that the setting is really that of a POMDP instead of an MDP: the agent does not fully observe the environment's state.  (It does not observe $\mathbf{w}$, only the prefix that has been typed so far.)

## The agent (the predictive text system)

Even though the true setting is a POMDP, we will *construct* our agent using basic MDP methods.  Our simple agent will assume (incorrectly!) that the system state consists only of the part of $\mathbf{w}$ typed so far.  Well, almost: to simplify the task, we assume that the agent also telepathically knows how many characters of $\mathbf{w}$ are left to type, so this is part of the system state.  Examples:

true state of the environment | state assumed by the agent
-|-
`sequence␣modeling␣is␣ > the␣best` | `sequence␣modeling␣is␣` (8 left)
`sequence␣modeling␣is␣the␣ > best` | `sequence␣modeling␣is␣the␣` (4 left)
`sequence␣modeling␣is␣the␣b > est` | `sequence␣modeling␣is␣the␣b` (3 left)
`sequence␣modeling␣is␣the␣best` | `sequence␣modeling␣is␣the␣best` (0 left)

The agent's model of the environment thus has the form $p(y \mid s,a)$ where $s = (\mathbf{h},k)$ is the state assumed by the agent, consisting of the string $\mathbf{h}$ typed so far and the remaining number of characters $k$.  The action $a$ is the new substring that the agent proposes.  $y$ is the environment's response, which carries an associated reward $r(y)$.  

As mentioned earlier, our agent has a small set of actions in the MDP state $s$.  Its proposed string $a$ must have length $0 < |a| \leq k$, and $a$ must be the string of length $|a|$ that maximizes the language model probability $p(a \mid \mathbf{h})$.

Specifically, our agent's model $p(y \mid s,a)$ assumes that $p(\text{accept} \mid (\mathbf{h},k), a)$ is the probability, under the given language model, that a sentence starting with $\mathbf{h}$ and continuing for $k$ more characters would indeed choose $a$ as its next $|a|$ characters.

This model might be wrong if the agent's language model is not the actual distribution from which the environment sampled $\mathbf{w}$.  But it is also wrong because it makes an incorrect conditional independence assumption: the environment's behavior actually depends on the full POMDP state, not merely on the simplified MDP state that the agent observes and mistakenly assumes is enough to determine the environment's behavior.

To see this, imagine that the *actual* state of the POMDP is $\texttt{sequence␣modeling␣is␣ > the␣best}$, and that the agent first proposes $a=\texttt{that}$.  The environment (user) will reject it, typing $\texttt{t}$ instead.  Thus $y=\texttt{t}$, leading to an actual state of $\texttt{sequence␣modeling␣is␣t > he␣best}$.  A proper POMDP agent would not be able to observe this state, and specifically would not know the $\texttt{> he␣best}$ part.  But it would at least know that the POMDP state contains *some* such continuation, and it would maintain a posterior distribution ("belief state") over the possible states that the POMDP *might* be in, given the agent's observations.  In particular, it would know from the interactions so far that the POMDP cannot be in any state of the form $\texttt{> hat}\ldots$, because then the user would have accepted the previous proposal instead of rejecting it!  Thus, it would know that proposing $a=\texttt{hat}$ at this step would definitely lead to rejection.

In contrast, our MDP agent might very well propose $a=\texttt{hat}$, since it incorrectly thinks that the environment's probability of accepting that proposal is simply the probability that an arbitrary copy of $\texttt{sequence␣modeling␣is␣t}$ with 7 characters left would continue with $\texttt{hat}$ as its next 3 characters.  In truth, the environment has probability 0 of accepting that proposal.
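A toy distribution over continuations (hypothetical, made up purely for illustration, with three continuations of 8 characters each) makes the gap between the MDP agent's model and a proper belief state concrete. After the user rejected `that` and typed `t`, the MDP agent conditions only on the typed prefix, while a POMDP agent also conditions on the rejection. `accept_prob` is a hypothetical helper.

```python
# Hypothetical distribution over what the user still wants to type after "is␣".
CONTINUATIONS = {"the␣best": 0.5, "that␣one": 0.3, "a␣winner": 0.2}


def accept_prob(dist, typed, action):
    """P(the remaining text continues with `action` | it starts with `typed`)."""
    consistent = {w: p for w, p in dist.items() if w.startswith(typed)}
    z = sum(consistent.values())
    return sum(p for w, p in consistent.items()
               if w[len(typed):].startswith(action)) / z


# MDP agent: only conditions on the typed prefix "t".
mdp = accept_prob(CONTINUATIONS, "t", "hat")

# POMDP agent: its belief state also rules out every continuation starting
# with "that", because the user would have accepted that proposal.
belief = {w: p for w, p in CONTINUATIONS.items() if not w.startswith("that")}
pomdp = accept_prob(belief, "t", "hat")

print(f"MDP agent thinks P(accept 'hat') = {mdp:.3f}; belief state says {pomdp:.3f}")
```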

In [None]:
class TypistState(EnvironmentState):
    def __init__(self, *, string, start_index, debug=False):
        self.string = string
        self.start_index = start_index  # needed by initial_state()
        self.current_index = start_index
        self.debug = debug

    def initial_state(self):
        """
        Our state will be the already-typed part of the string.
        """
        return self.string[:self.start_index]

    def execute_action(self, *, action):
        """
        Given the action (a proposed string) check whether it is "correct"
        and reward accordingly.
        """
        # What are we looking for?
        goldanswer = self.string[self.current_index : self.current_index + len(action)]
        # Check what happens
        if self.debug:
            print("If we have \"" + self.string[:self.current_index]
                  + "\" and still need \"" + self.string[self.current_index:]
                  + "\", the proposition \"" + action, end='" ')
        if action == goldanswer:
            # Advance the task state that far
            self.current_index += len(action)
            reward = len(action) - 1
            response = action
            if self.debug:
                print("is correct!", end=' ')
        else:
            # Only advance by one
            self.current_index += 1
            reward = 0  # 1 character - 1 keypress
            response = goldanswer[0]
            if self.debug:
                print("is incorrect (\"" + goldanswer + "\" would have been correct)!", end=' ')
        if self.debug:
            print("We get reward", reward, "and response \"" + response + "\".")
        return reward, response
    
    def evaluate_agent(self, *, agent=None, agentclass=None, return_agent=False, **kwargs):
        assert agent is None or agentclass is None
        if agent is None:
            agent = agentclass(
                prefix=self.string[:self.current_index],
                nchars_left=len(self.string)-self.current_index,
                **kwargs  # this will be the way to pass the LM
            )
            
        total_reward = 0
        if self.debug:
            print(f'Starting with state "{self.string[:self.current_index]}>{self.string[self.current_index:]}":')

        while self.current_index < len(self.string):
            # Get action
            action = agent.decision()
            if self.debug:
                print(f'Agent chooses to predict "{action}" (length {len(action)})!')
            # Use reward and response
            reward, response = self.execute_action(action=action)
            agent.receive_response(reward=reward, response=response)  # keyword-only, see RLAgent; we will define this below.
            total_reward += reward
            if self.debug:
                print(f'User gave reward {reward} and response "{response}"')
        if self.debug:
            print("Received", total_reward, "total reward!")
        return total_reward if not return_agent else (total_reward, agent)

Let's see what our possible actions are in the initial state of the run that we saw above, and in some possible subsequent states.  For each action, we'll also see what the environment *would* do if the agent took that action.  Of course, the agent will only get to take *one* of the actions.

In [None]:
for nchars in range(21, 25):
    for length in range(1, 5):
        # Set up an environment just to where we can test the action...
        env = TypistState(string="sequence␣modeling␣is␣the␣best", start_index=nchars, debug=True)
        action = LM.best_string(prefix="sequence␣modeling␣is␣the␣best"[:nchars], length=length)
        env.execute_action(action=action)

## A first agent: greedily choosing actions

A first very simple agent is going to try to greedily choose the best action. Since our reward function gives meaningful rewards at every step, it's a sensible heuristic to always pick the action that gives the best expected immediate reward.  (This is not optimal, however, because it doesn't consider how the action affects future rewards.  We'll fix that later.)

In the full RL setting, the agent must learn how the environment behaves (i.e., what the rewards are and with which probabilities they occur) through trial and error.  However, in our simple scenario, we will give the agent a model that already *perfectly* describes the environment: the language model from which our user draws the sentences they want to type!

Equipped with this model and knowledge of our reward function, the agent can compute the *expected immediate reward* of each action $a$.  In our case, this is the reward $|a|-1$ that the agent receives if the user accepts $a$, weighted by the model probability that the user accepts $a$, plus a reward of 0 weighted by the model probability that the user instead rejects $a$ and types the next character:
$$\mathbb{E}[r \mid s, a] = p(\text{accept} \mid s, a)\,(|a| - 1) + \bigl(1 - p(\text{accept} \mid s, a)\bigr) \cdot 0 = p(\text{accept} \mid s, a)\,(|a| - 1).$$
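To see the formula at work, here is a hypothetical sketch that applies it to the toy trie from the warm-up (encoded as nested dicts with `"char/prob"` keys, like the argument to `draw_tree`), pretending the user has exactly 4 characters left. For each length it finds the most probable string by brute-force enumeration and weights the reward $|a|-1$ by that string's probability.

```python
# The warm-up toy trie: keys are "char/prob" labels, values are subtrees.
TOY_TRIE = {'e/0.4': {'s/1.0': {'t/0.7': {'a/1.0': {}}}},
            'm/0.6': {'a/0.5': {'t/1.0': {'i/1.0': {}}},
                      'e/0.5': {'d/1.0': {'␣/1.0': {}}}}}


def all_strings(node, length):
    """Yield every string of exactly `length` characters with its probability."""
    if length == 0:
        yield '', 1.0
        return
    for label, subtree in node.items():
        char, prob = label.split('/')
        for rest, p_rest in all_strings(subtree, length - 1):
            yield char + rest, float(prob) * p_rest


def expected_reward(p_accept, action):
    """Reward |a| - 1 if accepted, 0 if rejected (1 char for 1 keypress)."""
    return p_accept * (len(action) - 1) + (1 - p_accept) * 0


k = 4  # characters left to type (assumed known, as in the text)
rewards = {}
for length in range(1, k + 1):
    # max() keeps the first of several equally probable strings
    action, p = max(all_strings(TOY_TRIE, length), key=lambda sp: sp[1])
    rewards[length] = (action, expected_reward(p, action))
    print(length, action, round(rewards[length][1], 3))
```

On this toy trie the expected immediate reward keeps growing with length, so a greedy agent with 4 characters left would propose `mati`; it prefers `mati` over the equally probable `med␣` only because of tie-breaking order.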

In [None]:
def expected_immediate_reward(*, prefix, action, agent_lm):
    """
    The prefix is the string that we observed so far, the LM is the
    (potentially inadequate) LM of the agent.
    """
    # What would the probability of the string we would give be, i.e.,
    # what is the probability that it is correct?
    scorer = LMScorer(PredictStringTask(agent_lm.idx2char, len(action)), agent_lm)
    prob = scorer.score_aa(xx=prefix, aa=action).exp()
    # What would we get if it was indeed correct?
    # We assume knowledge of the reward function here of course.
    positive_reward = len(action) - 1
    # What if it was wrong? Well, we should really go over all the different
    # "rejections" that the user could give us, but for the task of scoring, all
    # that matters is that we get 1 character for 1 keypress, so the reward is 0.
    # That's why we can succinctly say that the expected reward is...
    return prob * positive_reward + 0

Let's actually compute those expected immediate rewards. What is the immediate reward we expect from taking each action when we are in the very first state?

In [None]:
prefix = "sequence␣modeling␣is␣"
for length in range(1, 6):
    action = LM.best_string(prefix=prefix, length=length)
    er = expected_immediate_reward(prefix=prefix, action=action, agent_lm=LM)
    print(length, f"{action:5} {er.item():.4f}")

Unsurprisingly, predicting four characters is best: once the model is confident in "t", adding "h", "e", and "␣" adds little risk -- but much more reward.
Just to make sure you understand these quantities: 

- What is the minimum these numbers could ever be (for any action $a$ with $|a| \in \mathbb{N}$)? $\color{red}{\text{FILL IN}}$
- How many actions will achieve that minimum (given the model we use here)? $\color{red}{\text{FILL IN}}$
- How far do we have to look to find the action with the highest reward? $\color{red}{\text{FILL IN}}$
- Assume $a=\texttt{the␣}$ indeed achieves the maximum. Is this enough to tell us that we should definitely pick $a=\texttt{the␣}$? Why or why not? $\color{red}{\text{FILL IN}}$

So let's try to build a greedy agent to get us good actions!

In [None]:
class TextProposalAgent(RLAgent):
    """
    These things will be shared by all the agents we will build in the sequel.
    """
    def __init__(self, lm, prefix, nchars_left):
        self.lm = lm
        self.prefix = prefix
        self.prefix_hcs = lm.hcs_from_context(prefix)  # useful to be cached
        self.nchars_left = nchars_left

    def receive_response(self, reward, response):
        self.prefix += response
        self.prefix_hcs = self.lm.hcs_from_context(response, hcs=self.prefix_hcs)
        self.nchars_left -= len(response)

In [None]:
class GreedyExpectedImmediateRewardAgent(TextProposalAgent):
    def decision(self):
        """
        Make a decision based on `self.prefix` and `self.nchars_left` using
        `self.lm.best_string`.
        """
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END

# This should say: predict 4 characters, namely "the␣"
assert GreedyExpectedImmediateRewardAgent(lm=LM, prefix="sequence␣modeling␣is␣", nchars_left=8).decision() == "the␣"

Now, given our Typist and our Agent, it's time to play this game!

In [None]:
# We should get a total reward of 5 here
string = "sequence␣modeling␣is␣the␣best"
start_index = 21
# Verbose
assert TypistState(string=string, start_index=start_index, debug=True) \
        .evaluate_agent(agent=GreedyExpectedImmediateRewardAgent(
            lm=LM,
            prefix=string[:start_index],
            nchars_left=len(string)-start_index
        )) == 5
# Or, shorter, using some magic:
assert TypistState(string=string, start_index=start_index, debug=True) \
        .evaluate_agent(agentclass=GreedyExpectedImmediateRewardAgent, lm=LM) == 5

Is that good or bad? Let's compare against a baseline that chooses randomly among the possible actions (all of which are rather good since they are high-probability under the LM):

In [None]:
# A random length agent
class RandomLengthAgent(TextProposalAgent):
    def decision(self):
        length = random.randint(1, self.nchars_left + 1)
        return self.lm.best_string(prefix=self.prefix, length=length)


In [None]:
# Run it a few times to see how well it does on average
rs = [
    TypistState(string=string, start_index=start_index).evaluate_agent(agentclass=RandomLengthAgent, lm=LM)
    for _ in range(10)
]
print("Rewards:", rs)
print("Average reward:", np.average(rs))

So our greedy agent is definitely a lot better. But let's not just evaluate this on one cherry-picked example:

In [None]:
def compare_agents(agent1class, agent2class, true_lm=LM, agent_lm=LM, n_sentences=5, total_length=20, start_index=12, verbose=True):
    """
    Sample some sentences from true_lm and compare the performance of
    the passed two agents on these sentences.
    """
    random.seed(0)
    sum_1, sum_2 = 0, 0
    for _ in range(n_sentences):
        # Draw a sentence from the LM (though at lower local temperature, so the agent model isn't 100% perfect)
        sample = ''.join(itertools.islice(true_lm.greedy_sample(temperature=0.2), total_length))
        # Test both agents
        agent1_reward = TypistState(string=sample, start_index=start_index).evaluate_agent(agentclass=agent1class, lm=agent_lm)
        agent2_reward = TypistState(string=sample, start_index=start_index).evaluate_agent(agentclass=agent2class, lm=agent_lm)
        # What happened!
        sum_1 += agent1_reward
        sum_2 += agent2_reward
        if verbose:
            print(f"On {sample[:start_index]}>{sample[start_index:]}, agent 1 got {agent1_reward}, agent 2 got {agent2_reward}.")
    return (sum_1, sum_2)

In [None]:
%time compare_agents(RandomLengthAgent, GreedyExpectedImmediateRewardAgent)

You will have noticed that this is far too slow to use in the remainder of the assignment.
That's why we will break all our nice abstractions a little to write a single (fast) function for our agent.

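We will inline the beam search, but more importantly, we will *fuse* the separate searches (one per proposal length) into one, because they all share the same beam search prefix: the beam that finds the best string of length $l$ is exactly the beam that gets extended to find the best string of length $l+1$.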
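To make the fusion idea concrete, here is a toy sketch in plain Python on a hand-written bigram table (hypothetical, standing in for the notebook's LSTM, whose hidden state plays the role of `prev` here). A single beam search records the top hypothesis at every depth, so one search answers the "best string of length $l$" query for all $l$ at once.

```python
import math

# Hypothetical bigram model: BIGRAM[prev][next] = log p(next | prev).
BIGRAM = {
    "␣": {"t": math.log(0.6), "a": math.log(0.4)},
    "t": {"h": math.log(0.9), "o": math.log(0.1)},
    "h": {"e": math.log(0.95), "a": math.log(0.05)},
    "e": {"␣": math.log(0.7), "n": math.log(0.3)},
    "a": {"␣": math.log(0.5), "n": math.log(0.5)},
    "o": {"␣": 0.0},
    "n": {"␣": 0.0},
}


def fused_best_strings(last_char, max_length, beam_size=3):
    """Return best[l] = (logprob, string) of the best length-l continuation,
    for l = 0..max_length, using one shared beam search."""
    beam = [(0.0, "", last_char)]
    best = [(0.0, "")]
    for _ in range(max_length):
        candidates = []
        for score, string, prev in beam:
            for nxt, logp in BIGRAM[prev].items():
                candidates.append((score + logp, string + nxt, nxt))
        if not candidates:
            break
        candidates.sort(key=lambda c: -c[0])
        beam = candidates[:beam_size]
        # The top of the beam at this depth is the answer for this length.
        best.append((beam[0][0], beam[0][1]))
    return best


for logprob, string in fused_best_strings("␣", 4)[1:]:
    print(repr(string), round(math.exp(logprob), 3))
```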

In [None]:
def best_strings_from_hcs(self, *, hcs, max_length, beam_size=5):
    """
    Returns a list z, such that z[l] = (logprob, string, hcs) of the best
    string of length l.
    """
    beam_size = min(beam_size, len(self.idx2char))
    queue = [(torch.tensor(0.0), "", hcs)]
    returns = [queue[0]]
    while len(queue[0][1]) < max_length:
        next_queue = []
        for pscore, ptaskstate, pmodelstate in queue:
            probs = (pmodelstate[-1][0] @ self.embedding.weight.t()).log_softmax(dim=-1)
            for nscore, idx in zip(*probs.squeeze(0).topk(beam_size)):
                nmodelstate = self._hcs_from_cidx(pmodelstate, torch.tensor([idx]))
                next_queue.append((nscore + pscore, ptaskstate + self.idx2char[idx], nmodelstate))
        if len(next_queue) == 0:
            break
        next_queue.sort(key=lambda x: -float(x[0]))
        queue = next_queue[:beam_size]
        returns.append(queue[0])
    return returns

# Patch it, too, into the language model
LanguageModel.best_strings_from_hcs = best_strings_from_hcs

In [None]:
# Try it out:

print("Our old function")
start_time = time.time()
for _ in range(1):
    for a in range(1, 8+1):
        print(LM.best_string(prefix="sequence␣modeling␣is␣", length=a))
print(time.time() - start_time)

print("\nOur new function")
start_time = time.time()
hcs = LM.hcs_from_context("sequence␣modeling␣is␣")
print('\n'.join([x[1] for x in LM.best_strings_from_hcs(hcs=hcs, max_length=8)][1:]))
print(time.time() - start_time)

print("\nOur new function, much faster")
start_time = time.time()
hcs = LM.hcs_from_context("sequence␣modeling␣is␣")
print('\n'.join([x[1] for x in LM.best_strings_from_hcs(hcs=hcs, max_length=50)][-10:]))
print(time.time() - start_time)

Much better! Let's use it to implement faster agents! The main objective here was to fuse the searches into one search that also returns partial results, but note that avoiding the boilerplate of beam search also gives the random agent a nice speedup: 

In [None]:
class FastRandomLengthAgent(TextProposalAgent):
    def decision(self):
        length = random.randint(1, self.nchars_left + 1)
        hyps = self.lm.best_strings_from_hcs(hcs=self.prefix_hcs, max_length=length)
        return hyps[-1][1]

class FastGreedyExpectedImmediateRewardAgent(TextProposalAgent):
    def decision(self):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


In [None]:
# Compare on two sentences
for sample in ["modeling␣is␣the␣best", "kare␣to␣from␣first␣g"]:
    print("Predict", sample[:12], ">", sample[12:])
    
    def get_reward(agentclass):
        random.seed(0)
        user = TypistState(string=sample, start_index=12)
        return user.evaluate_agent(agentclass=agentclass, lm=LM)
    # Test the random speedup
    %time      random_reward = get_reward(                     RandomLengthAgent)
    %time fast_random_reward = get_reward(                 FastRandomLengthAgent)
    # Test the greedy speedup
    %time      greedy_reward = get_reward(    GreedyExpectedImmediateRewardAgent)
    %time fast_greedy_reward = get_reward(FastGreedyExpectedImmediateRewardAgent)
    assert greedy_reward == fast_greedy_reward, (greedy_reward, fast_greedy_reward)
    assert random_reward == fast_random_reward, (random_reward, fast_random_reward)

By breaking our beautiful abstraction and making some strong assumptions, we have made our solution about 20x faster!
Now, let's run a little comparative study:

In [None]:
%time compare_agents(FastRandomLengthAgent, FastGreedyExpectedImmediateRewardAgent)

Down from 42s to 3 seconds -- nice! Let's try a few more:

In [None]:
%time sum_random, sum_greedy = compare_agents(FastRandomLengthAgent, FastGreedyExpectedImmediateRewardAgent, n_sentences=30)
print("Totals! Random:", sum_random, "-- Greedy:", sum_greedy)

Interestingly, sometimes the random agent seems to outperform the greedy agent. To be fair, of course, the random agent isn't quite that random -- it still extracts the best string of a given length from the LM. A *real* random agent would be remarkably terrible:

In [None]:
class TrulyRandomAgent(TextProposalAgent):
    def decision(self):
        # Pick a random proposal length...
        length = random.randint(1, self.nchars_left)
        # ...and fill it by sampling at temperature 1000, which makes the
        # distribution practically uniform (i.e., truly random)
        sample = self.lm.greedy_sample(hcs=self.prefix_hcs, temperature=1000)
        return ''.join(itertools.islice(sample, length))

compare_agents(TrulyRandomAgent, FastGreedyExpectedImmediateRewardAgent, n_sentences=10)

## Simplifying things: predict from a three-letter alphabet

Because this is still plenty slow even with our speedups, we will simplify the task a little: our strings are now over an alphabet of just three characters, "x", "o", and our old readability-improving friend "␣". The language model will also be much smaller and thus faster.

What data are we training on? We follow established NLP practice and pretend Jason's NLP class homework 1 toy grammars are somehow meaningful.  (Established NLP practice?  Please, catch the sarcasm and read [Yoav Goldberg's great takedown](https://medium.com/@yoav.goldberg/an-adversarial-review-of-adversarial-generation-of-natural-language-409ac3378bd7) of someone outside JHU who tried to write a real paper using those grammars.)  We generate some random sentences from one of those grammars, and then replace all characters in terminal symbols with "x" or "o" and replace spaces with "␣".  This will be our training dataset for the language model.

```
$ ./randsent holygrail.gr 350000 | tr '\nabcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ. ' ' xoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxox ' | head -c 10000000 | sed 's/ /␣/g' | gzip > xo_corpus.txt.gz

$ zcat xo_corpus.txt.gz | sed 's/␣/\n/g' | sort -nr | uniq -c | sort -nr
 382815 x
 102040 xoxox
  92066 oxoxx
  72006 ox
  71656 ooxx
  69657 oox
  69640 xoxooxo
  69508 xox
  69495 ooxo
  69394 xxxo
  59820 xxooxo
  59718 oxxx
  52349 xxoxox
  52225 ooxoxx
  52175 oxx
  52111 xx
  52003 xxooxxx
  37677 xxo
  [...]
```

In [None]:
retrain_small_model = False

if retrain_small_model:
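    if not os.path.isfile("xo_corpus.txt.gz"):
        # urllib.request is already imported at the top of the notebook
        urllib.request.urlretrieve("https://sjmielke.com/tmp/xo_corpus.txt.gz", "xo_corpus.txt.gz")

    with gzip.open("xo_corpus.txt.gz", 'rt') as f:
        xo_corpus = f.read()

    # These hyperparameters must match the XOLM architecture loaded below
    # (2 layers, 64 hidden units, 4-dim embeddings), or loading the state dict fails.
    train_lm("xo", xo_corpus, nlayers=2, nhid=64, embsize=4, BATCHSIZE=200, BPTTLENGTH=100)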

In [None]:
# Load the pretrained model
XOLM = LanguageModel("ox␣", layers=2, hidden_size=64, embedding_size=4).cpu()
XOLM.load_state_dict(torch.load("xo.lm.statedict.pt", map_location='cpu'))

In [None]:
"".join(itertools.islice(XOLM.greedy_sample(), 100))

Looks like we got the structure down (see the single "x"s that used to be "."s).
Just to check that this was worth it, let's rerun it on our game:

In [None]:
%time sum_random, sum_greedy = compare_agents(\
    FastRandomLengthAgent,\
    FastGreedyExpectedImmediateRewardAgent,\
    true_lm=XOLM,\
    agent_lm=XOLM,\
    n_sentences=100,\
    verbose=False\
)
print("Totals! Random:", sum_random, "-- Greedy:", sum_greedy)

That speed will do -- and we will need it, because what comes next makes everything much slower and more complicated.

## From greedy actions to planning

You will have noticed that we tried to stress that this (definitely better-than-random) agent of ours is *greedy* -- and you should know that we generally think that greediness is a bad thing. So far our reward function has been "benevolent", making greedy actions look reasonable enough, but the issue will become more apparent when we redefine the reward slightly.  We want to help the user enter the text in as few steps as possible (remember that the user has to press one key per step).  So let's define the reward to be $-1$ on every step (so it's really a penalty).  The total reward of an episode is the same as before, except that the old reward function added the constant $|\mathbf{w}|$ to the total reward (compare the two reward functions on an example if you want to check this).  But the reward is now distributed differently across the time steps.  With the new function, all actions have the same immediate reward, so they are all tied and the greedy agent has no way to choose!  
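To double-check the claim that only the constant $|\mathbf{w}|$ separates the two totals, here is a small sketch. The helper `episode_returns` and the example episode are hypothetical; it assumes, as in `expected_immediate_reward`, that an accepted proposal $a$ earns $|a|-1$ under the old reward and that a rejection earns 0 because the user types exactly one character.

```python
def episode_returns(steps, total_chars):
    """Compare the old and new reward functions on one episode.

    `steps` lists how many characters each step added to the text: the
    proposal length if it was accepted, or 1 if the user rejected it and
    typed a single character. `total_chars` is |w|, the characters to type.
    """
    assert sum(steps) == total_chars, "steps must produce exactly |w| chars"
    old_return = sum(chars - 1 for chars in steps)  # len(action) - 1, or 0
    new_return = sum(-1 for _ in steps)  # a flat -1 penalty per keypress
    return old_return, new_return


# Hypothetical episode with |w| = 8: accept "the␣" (4 chars), reject
# twice (1 char each), then accept a 2-char proposal.
old, new = episode_returns([4, 1, 1, 2], total_chars=8)
print(old, new, old - new)
```

Both totals only depend on the number of steps, so both prefer short episodes; the difference is purely in how the reward is spread over the steps.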

Once again, the solution is a kind of lookahead: let's *plan* out the rest of the trajectory! This allows us to not greedily choose the action that maximizes immediate expected reward (i.e., minimizes immediate harm under our nasty new reward function), but instead choose the action that will give us the highest expected *return* (total reward over all further steps). Under our new reward function, that means trying to find short trajectories, which have a less negative return.

There is an obvious complication: to plan ahead, we would have to know what response the user will give to each of our actions. Since we don't know that, we have to again use the agent's language model to hypothesize what the user might say -- and then we follow each path.

Take a look at this picture, showing one such *game tree* for predicting a length-3 string using the letters "a", "b", and "c":

<img width="100%" src="https://sjmielke.com/tmp/gametree.png" />

Each node of this tree is a *state* containing information about the string we've seen so far and the actions that took us there; the states are aligned by timestep in the to-be-predicted string. In blue we see the actions that can be taken (with the corresponding 1-best string under it), then for each action, the paths split depending on whether the user accepted the whole proposal (green arc) or rejected it (red arcs, for each possible rejection), both putting us in a new state from which we continue reasoning... until we've reached the end! As you can see, even for this little example, there are a *lot* of paths, namely as many as there are final states.

How many is that in this example? $\color{red}{\text{FILL IN}}$ 

How many is that in general for predicting $n$ characters from a set of $m$ symbols? (Hint: use recursion) $\color{red}{\text{FILL IN}}$

How many paths then would we have in our old prediction task with the old language model? $\color{red}{\text{FILL IN}}$ 

You should have arrived at an answer that more or less says "utterly infeasible". Of course, in the spirit of learning by pain, we will still try to construct an agent that does precisely this: expand *all* paths in the game tree -- for our new, very small task, of course.

## The expected return is the expected total reward: marginalizing over all paths

How can we use this game tree to find the action that seems "best"? We are still looking to maximize reward, but this time we will not only see how much the immediate action (i.e., the immediate outgoing arc of the start state in our game tree) would give us and with which probability, but we will play each of these outcomes until the end -- only then will we know the worth of the initial action.

But how do we get from knowing a completed trajectory's return and having probabilities for environment responses (our LM) to tallying up the worth of that first initial action?

We should **marginalize** over environment responses, but **maximize** over agent choices.  That is because the agent's *policy* is to always choose the action it believes to be best. If we were using a stochastic policy, we would marginalize here, too.
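This max/sum alternation is often called *expectimax*. Here is a minimal sketch on a hand-built toy tree (not the notebook's `TypistState` or LM; all probabilities are made up, and for simplicity both responses to the one-character proposal lead to the same follow-up state). With two characters left and the new $-1$-per-step reward, proposing "x" or "xo" both have immediate reward $-1$, so a greedy agent is tied, but the full expectation breaks the tie.

```python
def expectimax(decision_node):
    """Value of a state in a small, explicitly built game tree.

    A decision node is a dict {action: chance_node} (None means terminal);
    a chance node is a list of (probability, reward, next_decision_node).
    """
    if decision_node is None:
        return 0.0  # terminal: no more rewards
    # The agent chooses: maximize over its actions...
    return max(action_value(chance) for chance in decision_node.values())


def action_value(chance_node):
    # ...while the environment responds randomly: marginalize over responses.
    return sum(p * (r + expectimax(nxt)) for p, r, nxt in chance_node)


# One character left: the only action is a 1-char proposal; accepted or
# rejected, the episode ends after this step.
one_left = {"o": [(0.8, -1, None), (0.2, -1, None)]}
two_left = {
    # Propose 1 char: accepted or rejected, one character remains afterwards.
    "x": [(0.9, -1, one_left), (0.1, -1, one_left)],
    # Propose 2 chars: if accepted we are done, otherwise one remains.
    "xo": [(0.5, -1, None), (0.5, -1, one_left)],
}

values = {a: action_value(chance) for a, chance in two_left.items()}
best_action = max(values, key=values.get)
print(values, best_action)
```

The recursion visits every path once, which is exactly why its cost explodes with the length of the remaining string.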

In [None]:
class ExhaustivePlanningAgent(TextProposalAgent):
    def decision(self):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


In [None]:
sample = "x␣xox␣oxoxxoxo␣xx␣xoxox␣xxooxoo␣x␣xoxooxo␣oxoxx␣ooxoxx␣ooxo␣oxxoo␣x"

for maxlen in range(1, 6):
    user = TypistState(string=sample[:12+maxlen], start_index=12, debug=True)
    %time user.evaluate_agent(agentclass=ExhaustivePlanningAgent, lm=XOLM)

This is still embarrassingly slow: runtime scales linearly with the number of paths, and that number scales exponentially in the length of the to-be-predicted suffix:

suffix length | paths | time (ms)
-|-|-
1 | 3 | 21
2 | 19 | 36
3 | 175 | 277
4 | 2123 | 3720
5 | 32043 | 66000
6 | 579095 | 1556000

How could we fix this? Of course, we could simplify our model --  and in a sense we already did that in the previous agents! The way we set up their internal belief state, our game tree is actually not a tree, but a DAG. Why? How many nodes does it have? $\color{red}{\text{FILL IN}}$

You should have come up with a number that is $O(|\mathbf{w}|)$ -- can we get a DAG that has $O(1)$ nodes? If yes, explain how, if no, explain why. $\color{red}{\text{FILL IN}}$

Could we do it if we didn't have the number of characters left as part of our state? $\color{red}{\text{FILL IN}}$

But we could also perform approximate inference, and we will do just that: we will (finally) start using terms and techniques from reinforcement learning to find our way through this enormous state space.

## Value functions judge states

If we could tell how good each state was, we could call off our search very early -- in fact, we could take our GreedyAgent and judge each action not only by how much immediate reward we get, but add in how good the state that we end up in will be (in terms of the total expected reward it can lead us to).

Let's make this more precise. We will try to learn a function $V$ such that $V(s)$ is the expected total reward that we obtain by starting in state $s$ if we always choose the best action (best according to this new function $V$). Learning this function will be an incremental process: we will start with some estimate (that is most likely wrong) and then iteratively refine it.

How can this refinement work? For any given state $s$, we can improve $V(s)$ by recomputing it from the states that the actions available in $s$ lead to.
Consider this example:

<img width="100%" src="https://sjmielke.com/tmp/state_to_state_with_model.png" />

In state $s$, we can take actions $a_1$ and $a_2$. What happens if we take one of them? This is where the agent's model of reality comes in (the one that in our example already told us how likely the user is to accept a given proposal): we know that taking action $a_1$ can land us in $s_{1,1}$, $s_{1,2}$, or $s_{1,3}$ (with certain probabilities, say $p_{1,1}$, $p_{1,2}$, and $p_{1,3}$).
Then we can recompute $V(s)$ using the immediate rewards $r_{*,*}$ and the value function estimates $V(s_{*,*})$!

As in the search case, the expected return from state $s$ on will be defined by the *best* action (since our *policy* is to always choose the best action), where "best" can be conveniently defined using $V$:
$$
\begin{align}
    \text{expected-return}(s, a_1) &= \mathbb{E}[r_{1,i} + \gamma V(s_{1,i})] = \sum_{i=1}^3 p_{1,i} \cdot (r_{1,i} + \gamma V(s_{1,i})) \\
    \text{expected-return}(s, a_2) &= \mathbb{E}[r_{2,i} + \gamma V(s_{2,i})] = \sum_{i} p_{2,i} \cdot (r_{2,i} + \gamma V(s_{2,i}))
\end{align}
$$

The *discount factor* $\gamma$ (usually a constant like $0.99$) is necessary in situations where trajectories can be infinitely long (think of cycles in some state graph), so that we don't just keep amassing rewards forever. Since our trajectories are finite, we can simply set $\gamma(s) = 1$ for all non-terminal states $s$.

Note that we can hardcode $V(s)$ to be $0$ for all terminal states $s$ -- if we're done, we will not get any more rewards (and why try to learn that when we already know it). *Note: in the literature, this is usually accomplished by setting $\gamma = 0$ when discounting the value of a terminal state (so that $V(s)$ can stay completely arbitrary for terminal $s$), but we choose to do things differently to simplify the exposition.*
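With the value after the terminal state fixed to 0, the return of a finished trajectory can be computed backwards one step at a time, as $G_t = r_t + \gamma G_{t+1}$. A minimal sketch in plain Python (the helper `discounted_returns` and the reward list are illustrative, not part of the notebook):

```python
def discounted_returns(rewards, gamma=1.0):
    """Return G_t = r_t + gamma * G_{t+1} for every step t of a finished
    trajectory, with the value after the terminal state fixed to 0."""
    returns = [0.0] * len(rewards)
    future = 0.0  # V(terminal) = 0: no more rewards once we are done
    for t in reversed(range(len(rewards))):
        future = rewards[t] + gamma * future
        returns[t] = future
    return returns


print(discounted_returns([-1, -1, -1]))  # gamma = 1: [-3.0, -2.0, -1.0]
print(discounted_returns([-1, -1, -1], gamma=0.99))
```

With $\gamma = 1$ and our $-1$-per-step reward, the return from each state is simply minus the number of remaining steps.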

Now, we can finally update $V(s)$ given that we know what action we would have taken in $s$ (say WLOG that it was $a_1$ that had the higher $\text{expected-return}(s, a)$), using the Bellman equation (also called the "Bellman backup operator"):

$$
    V(s) \leftarrow \sum_{i=1}^3 p_{1,i} \cdot (r_{1,i} + \gamma V(s_{1,i})) = \mathbb{E}[r_{1,i} + \gamma V(s_{1,i})]
$$
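Here is that backup applied repeatedly on a tiny *tabular* toy model, sketched in plain Python. The states, probabilities, and rewards are invented for illustration, and this is plain value iteration over an explicit table rather than the neural approach the notebook builds. Terminal states are hardcoded to 0, as described above.

```python
# Hypothetical toy model: MODEL[s][a] = [(p, r, next_state), ...].
MODEL = {
    "s": {
        "a1": [(0.5, -1, "s1"), (0.3, -1, "s2"), (0.2, -1, "end")],
        "a2": [(0.9, -1, "s2"), (0.1, -1, "s1")],
    },
    "s1": {"a": [(1.0, -1, "s2")]},
    "s2": {"a": [(0.6, -1, "end"), (0.4, -1, "s2")]},
}
TERMINAL = {"end"}


def q_value(V, s, a, gamma=1.0):
    """expected-return(s, a) = sum_i p_i * (r_i + gamma * V(s_i))."""
    return sum(
        p * (r + gamma * (0.0 if nxt in TERMINAL else V[nxt]))  # V(terminal) = 0
        for p, r, nxt in MODEL[s][a]
    )


def bellman_backup(V, s, gamma=1.0):
    """V(s) <- expected return of the best action in s."""
    return max(q_value(V, s, a, gamma) for a in MODEL[s])


V = {s: 0.0 for s in MODEL}  # arbitrary (and wrong) initial estimates
for sweep in range(200):  # repeat the backup for every state until it settles
    for s in MODEL:
        V[s] = bellman_backup(V, s)

best_in_s = max(MODEL["s"], key=lambda a: q_value(V, "s", a))
print({s: round(v, 3) for s, v in V.items()}, best_in_s)
```

Even though both actions in `s` cost $-1$ immediately, the converged values prefer `a2`, because it tends to lead to states closer to the end.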

If we iterate that process for all our states $s$, we will learn a good value function $V$ that will lead us to make *globally optimal greedy decisions*!
Pretty cool! But wait... iterate for *all* states? Clearly that's utterly infeasible. And what should this function $V$ look like internally? A big table of all the millions of states? That can't be it... We will tackle these two problems in turn:

### Deep RL: approximating $V$ with a neural network

The simplest option for implementing a function $V$ is a big table in memory where for each $s$ a separate value is stored.
This setting is known as *tabular RL* ("tabular" is the adjectival form of "table").  While it gives many nice guarantees for convergence etc., it is not very practical for real-world AI settings.  The issue is that most problems have infeasibly large state spaces: not only can we not fit these giant tables in our computer memory, but we also lose out by not *sharing information* between related states.

That is where "Deep" comes into Reinforcement Learning. Instead of having discrete entries for states $s$, let $V$ be a neural network that outputs a scalar given some vector representation of the state $s$.

For our purposes this means that the agent has to encode its state (or really the belief state, since we are in a POMDP, but we will ignore that as a technicality in the sequel) into a vector somehow. Well, we are in luck: we can just use the hidden state of the pretrained language model that we are already using.
Given this vector we could build an arbitrarily complicated neural network that outputs a scalar -- we will try for a simple linear regressor and some feedforward networks.

This of course also means that the assignment given in the equation above is meaningless: we cannot just "assign" an output value to a neural network. We will instead minimize the squared distance between the old estimate for $V(s)$ and the one given by the right-hand side of the Bellman equation.

In [None]:
class ValueFunctionApproximator(torch.nn.Module):
    def __init__(self, h_dims):
        super().__init__()
        # The main network
        self.linear1 = torch.nn.Linear(h_dims + 2, 16)
        self.linear2 = torch.nn.Linear(2 * 16 + 1, 1)
        # The more stable, since only slowly and indirectly updated target network
        self.linear1_target = torch.nn.Linear(h_dims + 2, 16)
        self.linear2_target = torch.nn.Linear(2 * 16 + 1, 1)

    def update_target_network(self):
        """
        Set the target network weights to a new moving average.
        """
        self.linear1_target.weight.data = 0.2 * self.linear1.weight.data + 0.8 * self.linear1_target.weight.data
        self.linear2_target.weight.data = 0.2 * self.linear2.weight.data + 0.8 * self.linear2_target.weight.data
        self.linear1_target.bias.data = 0.2 * self.linear1.bias.data + 0.8 * self.linear1_target.bias.data
        self.linear2_target.bias.data = 0.2 * self.linear2.bias.data + 0.8 * self.linear2_target.bias.data

    def forward(self, hcs, nchars_left, target=False):
        """
        Using PyTorch syntax, we define `forward()` to give us the scalar from
        the state representation we feed in: the LM hidden state and the number
        of characters left.
        The `target` parameter tells us whether to use the (moving average)
        target network weights.
        """
        if nchars_left == 0:
            return torch.tensor(0.0)
        lmrep = hcs[-1][0][0]
        n = float(nchars_left)
        inp = torch.cat([lmrep, torch.tensor([n, math.log(n)])])
        hid = (self.linear1_target if target else self.linear1)(inp).tanh()
        inp = torch.cat([hid, hid / n, torch.tensor([n])])
        return (self.linear2_target if target else self.linear2)(inp).squeeze()

In [None]:
# This is how we would use it:
hcs, nchars_left = LM.hcs_from_context("sequence␣modeling␣is␣"), 8
V = ValueFunctionApproximator(LM.lstm_layers[-1].hidden_size)
V(hcs, nchars_left)

We slipped in another trick here.  We defined two versions of our neural network for predicting values.  The main network has layers `linear1` and `linear2`,  but we also have another network with the same topology, whose layers are called `linear1_target` and `linear2_target`. This *target network* will be updated more slowly than the main network, so that it doesn't oscillate during training.  Its weights will be a *moving average* of the weights of the main network.  The trick is to use this target network in place of the main network when we compute the right-hand side of the Bellman equation.  This can improve convergence rates.  Note that no gradients flow into the target network -- its weights will only change by manually averaging past weights of the main network.
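The moving average in `update_target_network` is sometimes called a *soft* or *Polyak* update. Here is a plain-Python sketch of the same rule on a list of weights (the helper `soft_update` and the weight values are hypothetical), showing how the target network trails a main network that suddenly jumps to new weights:

```python
def soft_update(target_params, main_params, tau=0.2):
    """target <- tau * main + (1 - tau) * target, elementwise, mirroring the
    0.2 / 0.8 weights used in `update_target_network` above."""
    return [tau * m + (1 - tau) * t for t, m in zip(target_params, main_params)]


# Hypothetical weights: the main network suddenly moves to new values...
target, main = [0.0, 0.0], [1.0, -2.0]
trace = []
for _ in range(10):
    target = soft_update(target, main)
    trace.append(round(target[0], 4))
# ...and the target follows smoothly: after k updates it has covered a
# fraction 1 - 0.8**k of the distance.
print(trace)
```

In PyTorch, a common idiom for the same update is to loop over `zip(target_net.parameters(), main_net.parameters())` inside `torch.no_grad()` and update each target tensor in place with `t.mul_(1 - tau).add_(tau * m)`; the notebook instead spells out each layer's `weight` and `bias` explicitly.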

In [None]:
# We will need the expected return for the Bellman equation and the decisions:
def expected_returns(*, prefix_hcs, nchars_left, value_function, agent_lm, gamma=0.99, target=False):
    """
    `prefix_hcs` and `nchars_left` encode our state and will be used to find
    actions and possible environment feedback, according to the `agent_lm` LM.
    Returns the proposals (a list of strings of length 1..nchars_left) and a 1-D
    tensor containing the expected returns for all these action proposals
    according to the `value_function`.
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


In [None]:
# Given such returns, the loss we want to minimize is easy to define as the
# squared distance between the current V estimate and the right-hand side of the
# Bellman equation, using `expected_returns`.
def value_function_loss_for_state(*, prefix_hcs, nchars_left, value_function, agent_lm):
    # The old value of the value function
    lhs = value_function(prefix_hcs, nchars_left)
    # Consult our model to get the next states and their values
    _, returns = expected_returns(
        prefix_hcs=prefix_hcs,
        nchars_left=nchars_left,
        value_function=value_function,
        agent_lm=agent_lm,
        target=True  # this is where we use the target network!
    )
    # Choose according to the argmax -- not our (exploration) policy! Because
    # that is what we want in the end. The corresponding expected return is:
    rhs = torch.max(returns)
    # Return the squared distance between that and the current estimate.
    # Note that we are NOT doing "residual gradient learning" here (this would
    # mean also using the gradients of the right-hand side to update), because
    # that often makes things optimize towards the wrong target. We want to
    # update the "past" using the "future". So, we detach:
    return torch.nn.functional.mse_loss(lhs, rhs.detach())

In [None]:
# That is how we would use the function:
value_function_loss_for_state(prefix_hcs=hcs, nchars_left=nchars_left, value_function=V, agent_lm=LM)

### Learning from rollouts

The next question is: which states are we updating our value function (approximation) with?
Since we can only interact with the true environment one action and response at a time, it is impossible to just reach any state we would want to (and even if that were not true, enumerating all states in our giant state space would still be absolutely infeasible).

The answer here is that we will *roll out* some *exploration policy*. (So far, we have always used an `argmax` policy: choose the action that will get you the most reward -- but we will change that a little below and actually inject some randomness. Just wait a second.)

Rolling out a policy in the environment, we obtain a *sample trajectory* that performs the task from start to end. Then, given this trajectory, we update $V(s)$ for all states that are on this trajectory.

Note that in the update equation given above it looks a little like we are "trying out all the actions" -- but, like with the GreedyAgent, we are not testing all actions against the *true* environment, but only against the *model* of the environment that we have.

In [None]:
class ValueFunctionExpectedReturnAgent(TextProposalAgent):
    def __init__(self, lm, prefix, nchars_left, value_function, exploration_policy):
        """
        `exploration_policy` is a function that given a tensor of returns, yields the index
        of the action to take (usually, the best action).
        """
        self.lm = lm
        self.prefix = prefix
        self.prefix_hcs = lm.hcs_from_context(prefix)
        self.nchars_left = nchars_left
        self.value_function = value_function
        self.exploration_policy = exploration_policy
        self.visited_cache = []

    def receive_response(self, reward, response):
        self.prefix += response
        self.prefix_hcs = self.lm.hcs_from_context(response, hcs=self.prefix_hcs)
        self.nchars_left -= len(response)

    def decision(self):
        # Save visited state information
        hcs = tuple((hc[0].detach(), hc[1].detach()) for hc in self.prefix_hcs)
        self.visited_cache.append(
            {
                "prefix_hcs": hcs,
                "nchars_left": self.nchars_left
            }
        )
        # Now pick the one with the highest expected return!
        proposals, returns = expected_returns(
            prefix_hcs=self.prefix_hcs,
            nchars_left=self.nchars_left,
            value_function=self.value_function,
            agent_lm=self.lm
        )
        return proposals[self.exploration_policy(returns)]

In [None]:
# Try it again -- this time, everything is super-fast since we essentially make greedy decisions
for maxlen in range(1, 10):
    user = TypistState(string=sample[:12+maxlen], start_index=12)
    vfa = ValueFunctionApproximator(XOLM.lstm_layers[-1].hidden_size)
    %time user.evaluate_agent(agentclass=ValueFunctionExpectedReturnAgent, lm=XOLM, value_function=vfa, exploration_policy=lambda t: torch.argmax(t))

We end up with a runtime that depends roughly linearly on the sentence length (R²=0.91 on our measurements up to length 100), as expected. Nearly there!

### We need to force exploration

The way our update is defined, we will only ever visit and update the states that our value function tells us are good. So if a state is actually good but our current value function doesn't know that, the agent may never find out!
We therefore need to force our policy to *explore* states, even if they aren't the best according to the value function.

Specifically, we will use $\epsilon$-greedy sampling: with probability $1-\epsilon$, take the best action, and with probability $\epsilon$, sample uniformly from all *other* actions ($\epsilon$ is usually set to something like $.05$).
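As a framework-free illustration (not a drop-in replacement for `eps_greedy_policy`, which receives a `torch` tensor), here is ε-greedy selection over a plain Python list, together with a hypothetical two-armed bandit whose better arm starts with a pessimistic estimate. With `eps=0.0` the agent never pulls that arm and so never learns that it is better, which is exactly the failure described above; with a little exploration it finds it.

```python
import random


def eps_greedy_index(returns, eps=0.05, rng=random):
    """Pick argmax(returns) w.p. 1 - eps, otherwise a uniformly random *other* index."""
    best = max(range(len(returns)), key=lambda i: returns[i])
    others = [i for i in range(len(returns)) if i != best]
    if others and rng.random() < eps:
        return rng.choice(others)
    return best


def run_bandit(eps, steps=2000, seed=0):
    """Hypothetical 2-armed bandit: arm 0 always pays 0.5, arm 1 always pays 1.0.

    Arm 1's estimate starts pessimistically at 0.0, so a purely greedy agent never tries it.
    """
    rng = random.Random(seed)
    true_reward = [0.5, 1.0]
    estimate = [0.5, 0.0]  # initial value estimates
    counts = [1, 1]        # treat the initial estimates as one pseudo-observation each
    for _ in range(steps):
        arm = eps_greedy_index(estimate, eps, rng)
        counts[arm] += 1
        # incremental mean of the observed rewards
        estimate[arm] += (true_reward[arm] - estimate[arm]) / counts[arm]
    return estimate


print(run_bandit(eps=0.0))   # → [0.5, 0.0]: arm 1 is never tried
print(run_bandit(eps=0.05))  # arm 1's estimate climbs toward 1.0
```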

In [None]:
def eps_greedy_policy(returns, eps=.05):
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


With this, we are ready to write the entire training loop:

In [None]:
def train(*, agentclass, params_to_optimize, agent_lm, exploration_policy, loss_function, target_network_updater, sentences, log_interval, cache_gradients, **kwargs):
    random.seed(0)
    sum_reward, sum_loss = 0, 0
    # Only update the desired parameters (i.e., the main network, not the target network)
    optimizer = torch.optim.Adam(params_to_optimize, lr=1.0)
    optimizer.zero_grad()
    for i, sentence in enumerate(sentences):
        # Run in environment
        env = TypistState(string=sentence, start_index=0)
        reward, agent = env.evaluate_agent(
            agentclass=agentclass,
            lm=agent_lm,
            return_agent=True,
            exploration_policy=exploration_policy,
            **kwargs
        )
        # Construct loss for function approximator
        loss = torch.sum(
            torch.stack(
                [loss_function(agent_lm=agent_lm, **d, **kwargs) for d in agent.visited_cache]
            )
        )
        # Output
        sum_reward += reward
        sum_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print("Avg reward (using exploration policy):", sum_reward / log_interval, "Avg loss:", sum_loss / log_interval)
            print("Mean abs weights:", [p.abs().mean().item() for p in params_to_optimize])
            # print("VFA example:", ' '.join([f"{value_function(XOLM.hcs_from_context('ooxo␣xxoxoxxxo␣xxoxox␣xxo␣oxoxxoxo␣x'), ncl).item():.3f}" for ncl in range(1, 10)]))
            sum_reward, sum_loss = 0, 0
        # Optimize/update
        loss.backward()
        # Every `cache_gradients` iterations, step the optimizer on the main network, then move the target network towards it
        if (i + 1) % cache_gradients == 0:
            optimizer.step()
            target_network_updater()
            optimizer.zero_grad()

In [None]:
def train_vf(agent_lm, sentences, cache_gradients=2):
    """
    You can increase `cache_gradients` for increased stability...
    ...at the cost of slower convergence!
    """
    vfa = ValueFunctionApproximator(agent_lm.lstm_layers[-1].hidden_size)
    train(
        agentclass=ValueFunctionExpectedReturnAgent,
        params_to_optimize=list(vfa.linear1.parameters()) + list(vfa.linear2.parameters()),
        agent_lm=agent_lm,
        exploration_policy=eps_greedy_policy,
        value_function=vfa,
        loss_function=value_function_loss_for_state,
        target_network_updater=vfa.update_target_network,
        sentences=sentences,
        log_interval=20,
        cache_gradients=cache_gradients
    )
    return vfa

First let's try to overfit to a single sentence (this is going to take a few minutes; feel free to reduce the number of times we train on that sentence, which is set to 200 below).

In [None]:
testsentence = '␣xxo␣x␣ooxo␣xxxooxx␣oxoxx␣xoxo'
vfa_overfit = train_vf(XOLM, [testsentence] * 200)

In [None]:
print("Greedy gets:", TypistState(string=testsentence, start_index=0).evaluate_agent(agentclass=FastGreedyExpectedImmediateRewardAgent, lm=XOLM))

print(
    "After training, we get this reward using our argmax policy:",
    TypistState(string=testsentence, start_index=0).evaluate_agent(
        agentclass=ValueFunctionExpectedReturnAgent,
        lm=XOLM,
        value_function=vfa_overfit,
        exploration_policy=lambda t: torch.argmax(t)
    )
)

Nice (although we brutally overfit)! Let's do a qualitative check whether the value function learned what we wanted it to learn:

In [None]:
intermediate_sentence = '␣xxo␣x␣ooxo␣xxxooxx'

# The more characters we think are left, the higher we think the return will be:
print(' '.join([f"{vfa_overfit(XOLM.hcs_from_context(intermediate_sentence), ncl).item():.3f}" for ncl in range(1, 10)]))

# This is what our Bellman updates look like (with a well-trained model, both sides should be roughly equal):
prefix_hcs = XOLM.hcs_from_context(intermediate_sentence)
for ncl in range(1, 10):
    _, returns = expected_returns(prefix_hcs=prefix_hcs, nchars_left=ncl, value_function=vfa_overfit, agent_lm=XOLM)
    print(vfa_overfit(prefix_hcs, ncl).item(), "is Bellman-updated to come closer to", returns[torch.argmax(returns)].item())

# This is how the input features are used:
print(vfa_overfit.linear1.weight[0])

Before peeking ahead, should the value function *increase* or *decrease* the further we progress through the sentence/trajectory? Why? $\color{red}{\text{FILL IN}}$

In [None]:
# Let's check for our example sentence, what the value function at every position is:
# (we should see it go down to 0)
for i in range(len(testsentence)):
    hcs = XOLM.hcs_from_context(testsentence[:i])
    ncl = len(testsentence)-i
    loss = value_function_loss_for_state(
        prefix_hcs=hcs,
        nchars_left=ncl,
        value_function=vfa_overfit,
        agent_lm=XOLM
    )
    print("After", i, "characters, we have V =", vfa_overfit(hcs, ncl).item(), "incurring a loss of", loss.item())

Well, here's a fun oddity. Most of the states seem to get quite decent value estimates, but the negative numbers at the end look very wrong, and indeed the second-to-final state incurs a large loss. Why didn't we learn something better?

The answer is simple: we almost never visited this state, so we hardly ever updated our value function there! This sort of thing often happens in RL (and we're not going to try to fix it). We can confirm this by plotting the *stationary distribution* over states: run the agent a bunch of times and count which states were actually visited:

In [None]:
from collections import Counter

stationary = Counter()

for _ in range(50):
    env = TypistState(string=testsentence, start_index=0)
    _, agent = env.evaluate_agent(
        agentclass=ValueFunctionExpectedReturnAgent,
        lm=XOLM,
        return_agent=True,
        value_function=vfa_overfit,
        exploration_policy=eps_greedy_policy
    )
    for d in agent.visited_cache:
        stationary[d["nchars_left"]] += 1

for i in range(len(testsentence)):
    print(f"The state after {i:2} characters:", '#' * stationary[len(testsentence) - i])

We can see that some states were visited very often and some very rarely. Why is that?

$\color{red}{\text{FILL IN}}$

Alright, enough with this simple overfitting. It's time to train our agent on the real data distribution: every time it enters the environment, it has to guess a new sentence! Can we still do well?

In [None]:
def sample_generator(lm, temperature=0.2):
    while True:
        yield ''.join(itertools.islice(lm.greedy_sample(temperature=temperature), 30))

We will only train on 500 sentences; that should make us perform *okay*, but feel free to train for longer to see just how far this agent gets!

In [None]:
# Now train on some sentences from that generator
vfa_all = train_vf(XOLM, itertools.islice(sample_generator(XOLM), 500), cache_gradients=5)

Looks like it trained well! Now let's compare this new agent against our old greedy agent (`FastGreedyExpectedImmediateRewardAgent`) on many sentences as well:

In [None]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

def greedy_vs_vf_generator():
    for sentence in sample_generator(XOLM):
        reward_greedy = TypistState(string=sentence, start_index=0).evaluate_agent(
            agentclass=FastGreedyExpectedImmediateRewardAgent,
            lm=XOLM   
        )
        reward_vf = TypistState(string=sentence, start_index=0).evaluate_agent(
            agentclass=ValueFunctionExpectedReturnAgent,
            lm=XOLM,
            value_function=vfa_all,
            exploration_policy=lambda t: torch.argmax(t)
        )
        if abs(reward_vf - reward_greedy) > 4:
            print("Remarkable difference: on", sentence, "we have greedy:", reward_greedy, "value functions:", reward_vf)
        yield (reward_greedy, reward_vf)

def plot_agent_difference(reward_pair_generator):
    list1, list2, list_diff = [], [], []
    for i, (reward1, reward2) in enumerate(itertools.islice(reward_pair_generator, 100)):
        # Record the current pair first, so that each plot includes it
        list1.append(reward1)
        list2.append(reward2)
        list_diff.append(reward2 - reward1)
        # Every 50 sentences, plot both reward histograms and the histogram of differences
        if (i + 1) % 50 == 0:
            c1 = Counter(list1)
            c2 = Counter(list2)
            plt.bar(range(25), [c1[r] for r in range(25)], alpha=0.3)
            plt.bar(range(25), [c2[r] for r in range(25)], alpha=0.3)
            plt.legend(["agent 1", "agent 2"])
            plt.show()
            cd = Counter(list_diff)
            plt.bar([d - 10 for d in range(21)], [cd[d - 10] for d in range(21)])
            plt.legend(["2 - 1"])
            plt.show()
    print("\nsums:", sum(list1), "vs.", sum(list2))

plot_agent_difference(greedy_vs_vf_generator())

Looks like the value functions help just a little bit over the greedy agent! In the final section, we will ask the natural question: if we improve so little over our simple baseline, what is all the hype about? But first, some short notes on how you could make this VF training better and faster.

### Possible improvements

1. **Update on minibatches.** As with stochastic gradient descent, relying on single samples leads to instability and slow convergence, so usually we would sample many trajectories before performing an update on our value function.
2. **Unroll the Bellman equation.** Instead of taking the values of the successor states as given by $V$ at face value, we could compute them, too, as the expectation over these states' successors (i.e., brute-force lookahead for another step). This yields more precise results at the cost of more computation.
3. **Sample for the Bellman equation.** We have been exploiting the fact that we have a model $p$ to take an exact expectation over the next states that our action could lead to. This expectation may be slow to compute, especially when we unroll the equation over multiple timesteps as above. So we can instead approximate the expectation by sampling. This leads to an interesting tradeoff between unrolling depth (which helps exactness) and the number of samples we can afford for each decision at that depth (if we want to keep computation constant).
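Improvements 2 and 3 can be sketched on a hypothetical toy model in the spirit of our keyboard. Everything here is an assumption for illustration: `successors` plays the role of the model $p$ (proposals of length 1 or 2, accepted with made-up probabilities), and `v_hat` is a deliberately crude stand-in for a learned $V$. `lookahead` computes exact expectations unrolled to a given depth before trusting `v_hat`, while `sampled_lookahead` replaces every expectation by an average over sampled successors.

```python
import random

GAMMA = 1.0
# Hypothetical acceptance probability for each proposal length
ACCEPT_PROB = {1: 0.8, 2: 0.5}


def successors(n, k):
    """Hypothetical model p: (prob, reward, next_state) after proposing k chars with n left.

    Accepted: the user saves k keystrokes. Rejected: the user types one char, no reward.
    """
    p = ACCEPT_PROB[k]
    return [(p, float(k), n - k), (1.0 - p, 0.0, n - 1)]


def actions(n):
    return [k for k in ACCEPT_PROB if k <= n]


def v_hat(n):
    """Crude leaf estimate standing in for a learned value function V."""
    return 0.5 * n


def lookahead(n, depth):
    """Exact expectations, unrolled `depth` steps before falling back to v_hat."""
    if n == 0:
        return 0.0
    if depth == 0:
        return v_hat(n)
    return max(
        sum(p * (r + GAMMA * lookahead(s, depth - 1)) for p, r, s in successors(n, k))
        for k in actions(n)
    )


def sampled_lookahead(n, depth, num_samples, rng):
    """Same recursion, but each expectation is a Monte Carlo average over sampled successors."""
    if n == 0:
        return 0.0
    if depth == 0:
        return v_hat(n)
    best = float("-inf")
    for k in actions(n):
        outcomes = successors(n, k)
        weights = [prob for prob, _, _ in outcomes]
        total = 0.0
        for _ in range(num_samples):
            _, r, s = rng.choices(outcomes, weights=weights)[0]
            total += r + GAMMA * sampled_lookahead(s, depth - 1, num_samples, rng)
        best = max(best, total / num_samples)
    return best


exact = lookahead(5, depth=2)
approx = sampled_lookahead(5, depth=2, num_samples=200, rng=random.Random(0))
print(f"exact: {exact:.3f}  sampled: {approx:.3f}")
```

Each extra level of `depth` multiplies the work by the branching factor; `sampled_lookahead` makes roughly `(len(actions) * num_samples) ** depth` calls, which is exactly the depth-versus-samples tradeoff from item 3.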

## This sucks. Why would you do RL?

What we found out here is that our greedy agent was able to take very good actions -- even with our complicated value function approximation of the entire game tree, we had a hard time doing better. The reason is simple: we supplied our agent with both knowledge of the reward function and the language model that the environment actually used!
We basically handed our agent a complete (and correct!) model of the world: a model in which search truly was all that was ever needed. So instead of learning something about the environment (as one usually tries to do in RL), we just learned to search our existing perfect model of it. The fact that there are real trajectories with real rewards didn't matter to us at all.

This is a rather uncommon scenario in RL.
We've already hinted at the fact that things wouldn't go so well if we didn't give the agent a reward function that drips out sensible feedback on each iteration -- and it should be easy to see that if the language model that the agent has and the language model that the environment uses to choose $\mathbf{w}$ differ, our greedy agent would not get away that easily.

So, in this last section we will take away this perfect model and thus the capability of the agent to *reason* about what rewards it will get for any action. It will truly have nothing but the rewards it got from its current rollout to go on. Let's see how far this will get us. Can we still do better than the random agent (which is what the greedy agent essentially would fall back to in this case)?

### Learning $Q$ without modeling $p$

Say that we don't know $p$.  Where does our approach break if we no longer have this explicit environment model? We can identify two places:

1. During the Bellman update: we need to know which states $s'$ an action *could get us into*, both to evaluate whether it is the action we would take and to then look up the value of each such $s'$ in $V$. If we can't make these inferences anymore, what signal can we rely on? Only the transitions $s \overset{a}{\rightarrow} s'$ that we actually observed in our roll-out phase!
*Note that had we stuck to value function learning, our $\epsilon$-greedy exploration policy would then be "baked into" the estimates for $V$! That is of course undesirable, since we want to use $V$ with our optimal `argmax` policy later on: we want the value function to tell us what expected return we would get with that optimal policy, not with the suboptimal $\epsilon$-greedy policy. The method we will introduce, Q-learning, solves that problem nicely: with Q-learning, you can use **any** policy for exploration, and as long as it keeps trying every action in every state, you will get the same results (in the tabular case)! (Of course, this is not to say that learning value functions from rollout trajectories alone is impossible. One classic algorithm, [TD-Lambda](https://en.wikipedia.org/wiki/Temporal_difference_learning#TD-Lambda), was successfully used back in 1992 to build [TD-Gammon, an early win for RL that played backgammon close to the level of the best human players](https://en.wikipedia.org/wiki/TD-Gammon).)*

2. Both in the Bellman update and during the actual exploration, of course, we still need to decide which action to take! If we stuck to value functions, that would mean learning a neural net for $V$ and another one for a policy $\pi$ that, given $s$, tells us which $a$ to take. Why not fuse these two into one? Observe:

We will learn a *$Q$-function* that given a state *and action* tells us the expected return, i.e., instead of querying $V(s)$ for a state $s$ we will now query $Q(s, a)$ for some state-action-pair $(s, a)$.
Again, if we had this function, a policy would be trivial to derive:
$\pi(s) = \mathrm{argmax}_{a \in \mathcal{A}} Q(s, a)$ -- no need to query the perfect environment model anymore.
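To put a number on the "baked in" problem from the note above, here is a one-step example with hypothetical values: a single state whose three actions lead straight to termination with known returns 1.0, 0.4 and 0.0, and the ε-greedy policy described earlier with ε = 0.05. A value function learned from ε-greedy rollouts would estimate the first quantity below; the $\max_a Q(s, a)$ that $\pi(s)$ relies on corresponds to the second.

```python
EPS = 0.05
# Hypothetical Q(s, a) for the three actions of one state; every action ends the episode
q_values = [1.0, 0.4, 0.0]

best = max(range(len(q_values)), key=lambda i: q_values[i])
others = [q for i, q in enumerate(q_values) if i != best]

# Expected return of the eps-greedy policy: exploit w.p. 1 - eps, else uniform over the *other* actions
v_eps_greedy = (1 - EPS) * q_values[best] + EPS * sum(others) / len(others)
# Expected return of the argmax policy, i.e. max_a Q(s, a)
v_greedy = max(q_values)

print(f"V under eps-greedy: {v_eps_greedy:.3f}, V under argmax: {v_greedy:.3f}")
# → V under eps-greedy: 0.960, V under argmax: 1.000
```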

Most things stay the same as before, though. We again approximate $Q$ with a neural network, we will still use the hidden states of our pretrained LM as a convenient feature representation for this function approximator, and for the eventual proposal, we will still extract the 1-best strings for the proposed length from the pretrained language model (to keep things simple for this assignment).

So let's go through the only thing that really changes: the Bellman update for this $Q$-function. We already stated that we are in a situation where instead of being able to expand the game tree, we only observe one trajectory:

<img width="100%" src="https://sjmielke.com/tmp/state_to_state_without_model.png" />

The image hints at the solution: for a given state $s$ we can always define $V(s)$ using the $Q$-function: the value of the state is the maximum value you can get from taking any action $a$ in this state, where the value of taking $a$ in $s$ is precisely what $Q$ gives us. As the new Bellman update we obtain a new estimate for $Q(s, a_1)$ (because we took $a_1$):

$$
    Q(s, a_1) = r_{1,3} + \gamma V(s_{1,3}) = r_{1,3} + \gamma \max_{a' \in \{a_{1,3,1}, a_{1,3,2}\}} Q(s_{1,3}, a')
$$

Note that even though the equation refers to $s_{1,3}$, it should really say $s_{1,?}$ as we have no idea what other states (here, $s_{1,1}$ and $s_{1,2}$) we could have landed in! So, more simply, if we sampled $s \overset{a}{\rightarrow} s'$ giving reward $r$ during exploration, this is the corresponding update:

$$
    Q(s, a) = r + \gamma \max_{a' \,\text{from}\, s'} Q(s', a')
$$
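This update is the core of tabular Q-learning. Below is a minimal sketch on a hypothetical three-state toy problem (`step`, `ACTIONS` and all rewards are made up): the learner never queries a model, it only sees sampled `(reward, next_state)` pairs returned by `step`. Unlike the plain assignment above, it uses the standard learning-rate form $Q(s, a) \leftarrow Q(s, a) + \alpha \, (r + \gamma \max_{a'} Q(s', a') - Q(s, a))$ with $\alpha = 1/\text{visits}(s, a)$, which averages over noisy samples; $\alpha = 1$ would recover the assignment exactly. Exploration uses the ε-greedy variant from earlier, yet the greedy policy read off $Q$ is expected to prefer the delayed reward.

```python
import random
from collections import defaultdict

GAMMA = 0.9
# Hypothetical toy problem: "grab" pays 1 now; "wait" pays nothing now but leads to "mid",
# where "go" pays a noisy reward (2 or 4, mean 3). Terminal state "end" has no actions.
ACTIONS = {"start": ["grab", "wait"], "mid": ["go"], "end": []}


def step(state, action, rng):
    """The environment. The learner only ever sees its sampled outputs, never its internals."""
    if state == "start":
        return (1.0, "end") if action == "grab" else (0.0, "mid")
    return rng.choice([2.0, 4.0]), "end"


def max_q(Q, state):
    """max_{a'} Q(s', a'), taken to be 0 for terminal states."""
    return max((Q[state, a] for a in ACTIONS[state]), default=0.0)


def q_learning(episodes=2000, eps=0.3, seed=0):
    rng = random.Random(seed)
    Q = defaultdict(float)     # Q[state, action], initialised to 0
    visits = defaultdict(int)
    for _ in range(episodes):
        state = "start"
        while ACTIONS[state]:
            acts = ACTIONS[state]
            # eps-greedy behaviour policy: argmax w.p. 1 - eps, else one of the other actions
            best = max(acts, key=lambda a: Q[state, a])
            others = [a for a in acts if a != best]
            action = rng.choice(others) if others and rng.random() < eps else best
            reward, next_state = step(state, action, rng)
            # Q-learning update towards r + gamma * max_a' Q(s', a'), with alpha = 1 / visits
            visits[state, action] += 1
            target = reward + GAMMA * max_q(Q, next_state)
            Q[state, action] += (target - Q[state, action]) / visits[state, action]
            state = next_state
    return Q


Q = q_learning()
# pi(s) = argmax_a Q(s, a); expected to choose "wait" at "start" despite "grab" paying more right away
policy = {s: max(acts, key=lambda a: Q[s, a]) for s, acts in ACTIONS.items() if acts}
print(policy)
print({key: round(value, 2) for key, value in Q.items()})
```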

All right! Let's implement it!

In [None]:
class QFunctionApproximator(torch.nn.Module):
    def __init__(self, h_dims):
        super().__init__()
        # This time we will not only feed the prefix hidden state but also
        # the hidden state of the pretrained LM after reading in our proposal
        # -- this will serve as a representation of the action a -- as well as
        # the length and log-length of the proposed action.
        self.linear1 = torch.nn.Linear(h_dims + 2 + h_dims + 2, 16)
        self.linear2 = torch.nn.Linear(2 * 16 + 1, 1)
        # We will again use target networks for stability.
        self.linear1_target = torch.nn.Linear(h_dims + 2 + h_dims + 2, 16)
        self.linear2_target = torch.nn.Linear(2 * 16 + 1, 1)

    def update_target_network(self):
        """
        Set the target network weights to a new moving average.
        """
        self.linear1_target.weight.data = 0.2 * self.linear1.weight.data + 0.8 * self.linear1_target.weight.data
        self.linear2_target.weight.data = 0.2 * self.linear2.weight.data + 0.8 * self.linear2_target.weight.data
        self.linear1_target.bias.data = 0.2 * self.linear1.bias.data + 0.8 * self.linear1_target.bias.data
        self.linear2_target.bias.data = 0.2 * self.linear2.bias.data + 0.8 * self.linear2_target.bias.data

    def forward(self, *, hcs_state, nchars_left, hcs_proposal, proposal_length, target=False):
        """
        `hcs_state`: the state of the pretrained LM after reading the prefix
        `hcs_proposal`: ~ after also reading the proposed action
        `proposal_length`: length of the proposed extension
        `nchars_left` and `target` same as before.
        """
        if nchars_left == 0:
            return torch.tensor(0.0)
        n = float(nchars_left)
        l = float(proposal_length)
        inp = torch.cat(
            [
                hcs_state[-1][0][0],
                torch.tensor([n, math.log(n)]),
                hcs_proposal[-1][0][0],
                torch.tensor([l, math.log(l)])
            ]
        )
        hid = (self.linear1_target if target else self.linear1)(inp).tanh()
        inp = torch.cat([hid, hid / n, torch.tensor([n])])
        return (self.linear2_target if target else self.linear2)(inp).squeeze()
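
`update_target_network` performs a soft (Polyak-style) update: each call moves every target weight 20% of the way toward the corresponding main weight. A tiny scalar sketch, with a hypothetical main weight held fixed, shows that the remaining gap shrinks by a factor of 0.8 per call, so the target network lags the main network by a few updates instead of copying it immediately.

```python
TAU = 0.2  # fraction of the gap closed per call, matching the 0.2 / 0.8 weights in update_target_network


def soft_update(target, main, tau=TAU):
    """Polyak averaging: new_target = tau * main + (1 - tau) * target."""
    return tau * main + (1.0 - tau) * target


main_w, target_w = 1.0, 0.0  # hypothetical scalar weights; the main weight is held fixed here
gaps = []
for _ in range(10):
    target_w = soft_update(target_w, main_w)
    gaps.append(main_w - target_w)

# With the main weight frozen, the gap after k updates is 0.8 ** k times the initial gap of 1.0
for k, gap in enumerate(gaps, start=1):
    assert abs(gap - 0.8 ** k) < 1e-12
print(f"gap after 10 updates: {gaps[-1]:.4f}")  # → gap after 10 updates: 0.1074
```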


In [None]:
# This is how we would use it:
prefix_hcs, nchars_left = LM.hcs_from_context("sequence␣modeling␣is␣"), 8
proposal = "the␣best"
proposal_hcs = LM.hcs_from_context(proposal, hcs=prefix_hcs)
Q = QFunctionApproximator(LM.lstm_layers[-1].hidden_size)
Q(
    hcs_state=prefix_hcs,
    nchars_left=nchars_left,
    hcs_proposal=proposal_hcs,
    proposal_length=len(proposal)
)

In [None]:
def q_function_loss_for_state_action_reward_state(*, hcs_state1, nchars_left1, reward, hcs_state2, nchars_left2, q_function, agent_lm, gamma=0.99):
    """
    Both `hcs_state` and `nchars_left` have versions 1 and 2 for the two states
    we need here. Since we chose to encode the action by its final LM hidden
    state and its length, there is no need to feed in any more information about
    it, as the former is simply `hcs_state2` and the latter can be calculated as
    `nchars_left1 - nchars_left2`.
    We do, however, need to pass in the immediate reward that we received.
    Note that the `agent_lm` is only used to calculate all possible actions from
    state2, not to do model-based calculations!
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


In [None]:
# That is how we would use the function:
q_function_loss_for_state_action_reward_state(
    hcs_state1=prefix_hcs,
    hcs_state2=proposal_hcs,
    nchars_left1=nchars_left,
    nchars_left2=nchars_left - len(proposal),
    reward=8,
    q_function=Q,
    agent_lm=LM
)

In [None]:
class QFunctionExpectedReturnAgent(TextProposalAgent):
    def __init__(self, lm, prefix, nchars_left, q_function, exploration_policy):
        self.lm = lm
        self.prefix = prefix
        self.prefix_hcs = lm.hcs_from_context(prefix)
        self.nchars_left = nchars_left
        self.q_function = q_function
        self.exploration_policy = exploration_policy
        self.visited_cache = []

    def receive_response(self, reward, response):
        """
        This time we can only append a completed dict to the `visited_cache` list here, since the transition is only complete once we have seen the response!
        """
        # Assemble tuple
        hcs_state1 = self.prefix_hcs
        hcs_state2 = self.lm.hcs_from_context(response, hcs=self.prefix_hcs)
        nchars_left1 = self.nchars_left
        nchars_left2 = self.nchars_left - len(response)
        # Append
        self.visited_cache.append(
            {
                "hcs_state1": hcs_state1,
                "hcs_state2": hcs_state2,
                "nchars_left1": nchars_left1,
                "nchars_left2": nchars_left2,
                "reward": reward
            }
        )
        # Update internal state
        self.prefix += response
        self.prefix_hcs = hcs_state2
        self.nchars_left = nchars_left2

    def decision(self):
        """
        No need to append to our cache here (unlike the VF agent), but do
        remember to use the `self.exploration_policy`!
        """
        hcs = tuple((hc[0].detach(), hc[1].detach()) for hc in self.prefix_hcs)
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


In [None]:
def train_qf(agent_lm, sentences, cache_gradients=10):
    qf = QFunctionApproximator(agent_lm.lstm_layers[-1].hidden_size)
    train(
        agentclass=QFunctionExpectedReturnAgent,
        params_to_optimize=list(qf.linear1.parameters()) + list(qf.linear2.parameters()),
        agent_lm=agent_lm,
        exploration_policy=eps_greedy_policy,
        q_function=qf,
        loss_function=q_function_loss_for_state_action_reward_state,
        target_network_updater=qf.update_target_network,
        sentences=sentences,
        log_interval=50,
        cache_gradients=cache_gradients
    )
    return qf

Again, let's test whether we can overfit this one sentence. This time, we will probably need a lot more samples to get a reasonable estimate, because we only learn from the trajectories we actually observe. In real settings, "more samples" means millions or billions, and our code in this notebook already takes a long time for thousands. But maybe the simplicity of our problem still makes this feasible?

In [None]:
testsentence = '␣xxo␣x␣ooxo␣xxxooxx␣oxoxx␣xoxo'
qf_overfit = train_qf(XOLM, [testsentence] * 1000)

In [None]:
# Let's check for our example sentence, what the maximum Q value at every position is:
# (we should see it go down to 0)
for i in range(len(testsentence)):
    hcs = XOLM.hcs_from_context(testsentence[:i])
    ncl = len(testsentence)-i
    hyps = XOLM.best_strings_from_hcs(hcs=hcs, max_length=ncl)
    maxq = max(
        [
            qf_overfit(
                hcs_state=hcs,
                nchars_left=ncl,
                hcs_proposal=end_hcs,
                proposal_length=action_length
            ).item()
            for action_length, (_, _, end_hcs) in enumerate(hyps)
            if action_length > 0
        ]
    )
    print("After", i, "characters, we have a maximum Q of", maxq)

In [None]:
print("Random gets:", TypistState(string=testsentence, start_index=0).evaluate_agent(agentclass=FastRandomLengthAgent, lm=XOLM))

print(
    "After training, we get this reward using our argmax policy:",
    TypistState(string=testsentence, start_index=0).evaluate_agent(
        agentclass=QFunctionExpectedReturnAgent,
        lm=XOLM,
        q_function=qf_overfit,
        exploration_policy=lambda t: torch.argmax(t)
    )
)

You may or may not have gotten lucky and learned something useful. If not, you can restart and try again, or just keep reading. Deep reinforcement learning is very unstable and unpredictable!

For the final test, we again train on the true data distribution:

In [None]:
qf_all = train_qf(XOLM, itertools.islice(sample_generator(XOLM), 2000))

Time for the final test: how well does it perform against the random agent overall?

In [None]:
def random_vs_q_generator():
    for sentence in sample_generator(XOLM):
        reward_random = TypistState(string=sentence, start_index=0).evaluate_agent(
            agentclass=FastRandomLengthAgent,
            lm=XOLM
        )
        reward_qf = TypistState(string=sentence, start_index=0).evaluate_agent(
            agentclass=QFunctionExpectedReturnAgent,
            lm=XOLM,
            q_function=qf_all,
            exploration_policy=lambda t: torch.argmax(t)
        )
        if abs(reward_qf - reward_random) > 4:
            print("Remarkable difference: on", sentence,
                  "we have random:", reward_random, "vs. Q-functions:", reward_qf)
        yield (reward_random, reward_qf)

%matplotlib inline
plot_agent_difference(random_vs_q_generator())

You should see some very unconvincing numbers: the agent doesn't perform too well.
Why do you think that is? $\color{red}{\text{FILL IN}}$

*Note: we don't **know** the answer to that question for sure! But you can make some educated guesses -- or find a bug and make it work for extra credit? ;)*

# Congrats! You've won the game.

That's it for this assignment! If you want, play around with the reinforcement learning agent a bit more: try other exploration policies and see how they affect $V$ more than $Q$, try more complicated networks, try to speed things up by batching... If you ever use RL "in the wild", you will probably (as with most things in this class) write more task-specific and optimized code. But hopefully you have gotten a little glimpse of how reinforcement learning isn't all about robots moving, and can be instructive to think about in other settings, too!