# Homework 4 — A slightly different introduction to Deep Reinforcement Learning

In this homework we will tackle an entirely new problem that will allow us to show off some reinforcement learning -- a setting where our actions *alternate* with responses from the environment, and the two influence each other.

In [None]:
import gzip
import itertools
import json
import math
import numpy as np
import os
import random
import time
import urllib
import urllib.request

from IPython.display import HTML, display

import torch

from seq2class_previous_homeworks import StatefulTaskSetting, IncrementalScoringModel, BeamDecisionAgent, draw_tree


## The setting: predictive keyboards

You are likely familiar with modern predictive keyboards from your smartphone.  Given the text you've typed so far, the keyboard proposes a next word; the user can accept or ignore this proposal. You may have seen this in the fall NLP course as a contemporary application of *language models*; indeed, you already built an FST-powered completion and prediction tool in that class.

To keep things simple, we will assume that we *already have* some language model.  Furthermore, it's a language model over *characters*.  This gives us a conveniently small vocabulary.  It also allows us to ignore word boundaries: we will treat the space symbol as just another character.

As we are not making any use of words as units, we will no longer be proposing the next *word*.  Proposing just the *single next character* would probably not speed up typing enough, so our keyboard will propose a sequence of *several characters*.  The user can decide whether to accept this proposed substring or ignore it.  The major question will be *how many* characters to propose at once.

## Warm-Up: predicting strings from a language model trie

Here is an example: suppose the user has typed "`the␣defor`". Which characters should we propose next? Since we assume we already have a language model, suppose that the part of its trie starting at this prefix looks like this (omitting all the low-probability options):

In [None]:
display(HTML(draw_tree({'e/0.4': {'s/1.0': {'t/0.7': {'a/1.0': {}}}}, 'm/0.6': {'a/0.5': {'t/1.0': {'i/1.0': {}}}, 'e/0.5': {'d/1.0': {'␣/1.0': {}}}}})))

So the user might have wanted to type `deformed`, `deformation`, or `deforestation`. Given the very simple probabilities on the trie above, we can see what the most probable choice is if we want to propose a substring of 1, 2, 3, or 4 next characters:

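length | prediction | probability
---|---|---
1 | `m` | 0.6
2 | `es` | 0.4 × 1.0 = 0.4
3 | `mat` / `med`  (tie) | 0.6 × 0.5 × 1.0 = 0.3
4 | `mati` / `med␣`  (tie) | 0.6 × 0.5 × 1.0 × 1.0 = 0.3

The optimal longer strings do not all start with `m`.  Thus, we see that the best prediction of length 2, for example, cannot in general be obtained by greedily choosing the best arc at each trie vertex: greedy search commits to `m` (0.6) and then reaches only `ma` or `me` (0.3 each), whereas `es` has probability 0.4.  (For length 3, `est` is close behind at 0.4 × 1.0 × 0.7 = 0.28.)  Finding the best string requires some *planning ahead*.  (This shouldn't be surprising.)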
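
As a sanity check on this table, here is a small sketch in plain Python that does not use the homework's `LanguageModel`. It encodes the toy trie drawn above as nested dictionaries (`TOY_TRIE` is a hypothetical representation chosen for illustration) and compares exhaustive enumeration (`best_strings`) with greedy arc selection (`greedy_string`). All of these names are ours, not part of the provided code.

```python
import math

# Toy trie from the warm-up, as nested dicts: {char: (conditional prob, subtrie)}.
TOY_TRIE = {
    "e": (0.4, {"s": (1.0, {"t": (0.7, {"a": (1.0, {})})})}),
    "m": (0.6, {"a": (0.5, {"t": (1.0, {"i": (1.0, {})})}),
                "e": (0.5, {"d": (1.0, {"␣": (1.0, {})})})}),
}

def all_strings(trie, length, prefix="", prob=1.0):
    """Yield (string, probability) for every trie path of exactly `length` characters."""
    if length == 0:
        yield prefix, prob
        return
    for char, (p, child) in trie.items():
        yield from all_strings(child, length - 1, prefix + char, prob * p)

def best_strings(trie, length):
    """Exhaustive search: every string of maximal probability (ties kept), and that probability."""
    candidates = list(all_strings(trie, length))
    best_p = max(p for _, p in candidates)
    return sorted(s for s, p in candidates if math.isclose(p, best_p)), best_p

def greedy_string(trie, length):
    """Follow the most probable arc at each vertex (the first one listed on ties)."""
    string, prob = "", 1.0
    for _ in range(length):
        char, (p, trie) = max(trie.items(), key=lambda item: item[1][0])
        string, prob = string + char, prob * p
    return string, prob

for length in range(1, 5):
    best, p = best_strings(TOY_TRIE, length)
    greedy, gp = greedy_string(TOY_TRIE, length)
    print(f"{length}: best {best} (p={p:.2f}), greedy {greedy!r} (p={gp:.2f})")
```

Exhaustive enumeration is cheap here only because the toy trie has three paths; for the real 27-character LM it is exactly what becomes infeasible.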

How do we choose the length?  The user can only accept or ignore the entire proposed substring. Proposing a longer substring increases our *reward* (i.e., time saved over manual typing) if the user accepts -- but it also increases the chance that the proposal is wrong, in which case the user will be forced to ignore it and type the next character themselves.
That is the tradeoff that we will have to navigate in this assignment.
For now, we will assume that we do indeed have access to the full trie of the language model including all probabilities -- but at the end of the assignment we will see what we can still do even if we don't have this option to plan ahead.

We will give you a very simple character-level English LM over the 27 characters `a`-`z` and `␣` (the space character).  For simplicity, we do not include an EOS symbol.  Much of the code should be familiar to you from HW3.  Rather than make you run the training code (which takes about 20 minutes per epoch on 100MB of text, using a GPU), we will hand you the weights from our own training run.  

**Notes on the English language model:**

Our trained model achieves about 1.5 bits per character of cross-entropy (vs. 1.08 bits for the state-of-the-art Transformer-XL model).  Our model is weak because it uses relatively few parameters, and we only trained for 1 epoch, without any kind of regularization.
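
Bits per character is the average negative log-likelihood per character, measured in base 2. PyTorch's `log_softmax` uses natural logarithms, so the training loop below divides by $\ln 2$. A minimal sketch of that conversion (the helper name `bits_per_char` is ours):

```python
import math

def bits_per_char(total_nll_nats, num_chars):
    """Convert a summed negative log-likelihood (in nats) into bits per character."""
    return total_nll_nats / num_chars / math.log(2)

# A model that spreads probability uniformly over all 27 characters pays log2(27) bits each.
uniform_bpc = bits_per_char(100 * math.log(27), 100)
print(f"uniform over 27 chars: {uniform_bpc:.2f} bpc")
# 1.5 bpc is, on average, as uncertain as a uniform choice among 2**1.5 characters.
print(f"1.5 bpc ~ uniform choice among {2 ** 1.5:.2f} characters")
```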

You'll see that the training procedure is a little different from last homework's models: we group the training examples into batches, which allows us to parallelize SGD and improves its stability.  
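
The batching in `train_lm` below splits the corpus into `BATCHSIZE` contiguous streams and cuts each stream into chunks of `BPTTLENGTH` characters. Row `i` of batch `b + 1` therefore continues exactly where row `i` of batch `b` stopped, which is why the loop can carry the (detached) hidden state from one batch to the next. A sketch of the same slicing on a toy string (`make_batches` is a hypothetical helper that mirrors the list comprehension in `train_lm`):

```python
def make_batches(dataset, batch_size, bptt_length):
    """Same slicing as in `train_lm`: batch_size contiguous streams, each cut
    into consecutive chunks of bptt_length characters."""
    total_length = len(dataset) // batch_size          # characters per stream
    return [
        [dataset[total_length * i + bptt_length * b : total_length * i + bptt_length * (b + 1)]
         for i in range(batch_size)]                    # row i = stream i
        for b in range(total_length // bptt_length)     # batch b = b-th chunk of every stream
    ]

for b, batch in enumerate(make_batches("abcdefghijklmnopqrstuvwx", batch_size=2, bptt_length=4)):
    print(b, batch)
```

Reading down a column (`abcd`, `efgh`, `ijkl`) recovers one stream in order. Characters that do not fill a whole chunk are dropped by the integer divisions.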

As in HW3, we made the model more transparent by avoiding PyTorch's `LSTM` class (which runs over all time steps of a batch of sequences in a single call) in favor of the `LSTMCell` class (which computes only a single time step for all the sequences in the batch).  This slows down training because we have to call the GPU repeatedly, once per time step, but it makes it more convenient to step through the trie one character at a time and examine its probabilities as we explore reinforcement learning.  Don't do it this way in a real-world setting where speed is important.
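
The two classes compute the same function; the choice is about convenience and speed. A minimal sketch (arbitrary sizes and seed, CPU only) that copies an `LSTMCell`'s parameters into a one-layer `torch.nn.LSTM` and checks that stepping the cell over time reproduces the single fused call:

```python
import torch

torch.manual_seed(0)
input_size, hidden_size, seq_len, batch = 4, 3, 5, 2

cell = torch.nn.LSTMCell(input_size, hidden_size)
lstm = torch.nn.LSTM(input_size, hidden_size)  # one layer; input shape (seq_len, batch, input_size)

# Give the fused LSTM exactly the same parameters as the cell (both use the same gate layout).
with torch.no_grad():
    lstm.weight_ih_l0.copy_(cell.weight_ih)
    lstm.weight_hh_l0.copy_(cell.weight_hh)
    lstm.bias_ih_l0.copy_(cell.bias_ih)
    lstm.bias_hh_l0.copy_(cell.bias_hh)

x = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    # LSTMCell: one Python-level call per time step.
    h = torch.zeros(batch, hidden_size)
    c = torch.zeros(batch, hidden_size)
    stepwise = []
    for t in range(seq_len):
        h, c = cell(x[t], (h, c))
        stepwise.append(h)
    stepwise = torch.stack(stepwise)

    # LSTM: all time steps in a single call (zero initial state by default).
    fused, (h_n, c_n) = lstm(x)

print("max difference:", (stepwise - fused).abs().max().item())
```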


In [None]:
class LanguageModel(torch.nn.Module):

    def __init__(self, vocab, layers=3, hidden_size=512, embedding_size=16):
        super().__init__()
        
        # a simple character integerizer
        self.idx2char = vocab
        self.char2idx = {c: i for i, c in enumerate(vocab)}
        
        # character embedding module
        self.embedding = torch.nn.Embedding(len(vocab), embedding_size)
        
        # LSTM cells to run over character embeddings
        # Note that it may be easier to just use a torch.nn.LSTM instead of
        # manually gluing together cells like this.
        self.lstm_layers = [torch.nn.LSTMCell(embedding_size, hidden_size)] \
            + [torch.nn.LSTMCell(hidden_size, hidden_size) for _ in range(layers - 2)] \
            + [torch.nn.LSTMCell(hidden_size, embedding_size)]
        for i, layer in enumerate(self.lstm_layers):
            self.add_module(f"lstmlayer{i}", layer)

    def _hcs_from_cidx(self, hcs, c_idx):
        c_emb = self.embedding(c_idx)
        nhcs = [(c_emb, None)]
        for hc, layer in zip(hcs, self.lstm_layers):
            nhcs.append(layer(nhcs[-1][0], hc))
        return nhcs[1:]
        
    def hcs_from_context(self, context_string, hcs=None, get_all_hcs=False):
        """
        Sets up the language model hidden state given a string prefix
        and, optionally, the hidden state before reading in this prefix.
        
        If context_string is a list, this performs batched computations,
        interpreting context_string batch-first, timestep-second.
        """
        if isinstance(context_string, str):
            batchsize = 1
        else:
            # "Transpose" to get it timestep-first, batch-second
            batchsize = len(context_string)
            context_string = [tuple(cs) for cs in itertools.zip_longest(*context_string)]
        if hcs is None:
            # initialize the hidden state of the LSTM
            cs = [torch.zeros(batchsize, layer.hidden_size, device = self.embedding.weight.device) for layer in self.lstm_layers]
            hcs = [(torch.tanh(c), c) for c in cs]
        if get_all_hcs: all_hcs = [hcs]
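        # iterate over the string, one timestep at a time; `chars` is a single
        # character (unbatched) or a tuple with one character per sequence (batched)
        for chars in context_string:
            c_idxs = torch.tensor([self.char2idx[ch] for ch in tuple(chars)], device = self.embedding.weight.device)
            hcs = self._hcs_from_cidx(hcs, c_idxs)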
            if get_all_hcs: all_hcs.append(hcs)
        return all_hcs if get_all_hcs else hcs

    def next_options(self, hcs = None, logprobs = False):
        """
        Given the hidden state tuple of the LM, it returns a dictionary
        mapping potential next characters to their probabilities.
        """
        if hcs is None:
            hcs = self.hcs_from_context("")
        probs = (hcs[-1][0] @ self.embedding.weight.t()).log_softmax(dim=-1)
        if not logprobs:
            probs = probs.exp()
        l = [{c: p.item() for c, p in zip(self.idx2char, probs[b])}
             for b in range(probs.size(0))]
        # If no batching is used (i.e., batchsize == 1), return simple dict.
        if len(l) == 1:
            return l[0]
        else:
            return l
    
    def greedy_1_best(self, hcs = None):
        """
        Infinite iterator returning the greedy choices of next character.
        """
        with torch.no_grad():
            hcs = self.hcs_from_context("", hcs)  # initialize if needed
            while True:
                idx = torch.argmax(hcs[-1][0] @ self.embedding.weight.t())
                yield self.idx2char[idx]
                hcs = self._hcs_from_cidx(hcs, torch.tensor([idx], device = hcs[-1][0].device))

    def greedy_sample(self, temperature = 0.5, hcs = None):
        """
        Infinite iterator returning local samples of next character.
        """
        with torch.no_grad():
            hcs = self.hcs_from_context("", hcs)  # initialize if needed
            while True:
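                logits = hcs[-1][0] @ self.embedding.weight.t() / temperature
                # softmax is numerically safer than exponentiating raw logits at low temperature
                weights = logits.softmax(dim=-1).squeeze(0).tolist()
                [char] = random.choices(self.idx2char, weights = weights)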
                yield char
                hcs = self._hcs_from_cidx(hcs, torch.tensor([self.char2idx[char]], device = hcs[-1][0].device))
    
    def render_trie_from(self, hcs, depth=2, topk=-1, prob_gt=-1):
        assert topk < 0 or prob_gt < 0
        def trie_from(hcs, depth):
            if depth == 0:
                return {}
            else:
                expanded_cs = list(self.next_options(hcs).items())
                if topk > 0:
                    expanded_cs = sorted(expanded_cs, key=lambda x:-x[1])[:topk]
                if prob_gt > 0:
                    expanded_cs = [(c, p) for c, p in expanded_cs if p > prob_gt]
                return {f"{c}/{p:.2f}":
                        trie_from(self.hcs_from_context(c, hcs=hcs), depth - 1)
                        for c, p in sorted(expanded_cs, key=lambda x: x[0])}
        display(HTML(draw_tree(trie_from(hcs, depth))))

In [None]:
# This will train the language model (don't execute this, it would take too long!):

def train_lm(name, dataset, nlayers, nhid, embsize, BATCHSIZE, BPTTLENGTH):
    STUB = f"{name}_layers{nlayers}_hs{nhid}_emb{embsize}_bs{BATCHSIZE}_bptt{BPTTLENGTH}_adam1e-3_epochs1"
    start_time = time.time()
    with open(STUB+".log", 'wt') as logfile:
        print(STUB)
        print(STUB, file=logfile)
        vocab = "".join(sorted(list(set(dataset))))
        LM = LanguageModel(vocab, layers=nlayers, hidden_size=nhid, embedding_size=embsize).cuda()

        total_length = len(dataset) // BATCHSIZE
        batches = [
            [
                dataset[total_length * i + BPTTLENGTH * batch_nr
                      : total_length * i + BPTTLENGTH * (batch_nr + 1)]
                for i in range(BATCHSIZE)
            ]
            for batch_nr in range(total_length // BPTTLENGTH)
        ]

        hcs = LM.hcs_from_context([""] * BATCHSIZE)
        optimizer = torch.optim.Adam(LM.parameters(), lr=1e-3)

        nll_sum = 0
        for i_batch, batch in enumerate(batches):
            hcs = [(h.detach(), c.detach()) for (h, c) in hcs]
            nll = 0
            for i in range(BPTTLENGTH):
                # These chars at this timestep
                charidxs = torch.tensor([LM.char2idx[batch[r][i]] for r in range(BATCHSIZE)], device = LM.embedding.weight.device)
                # Predict
                lps = (hcs[-1][0] @ LM.embedding.weight.t()).log_softmax(dim=-1)
                nll += -lps.gather(dim=-1, index=charidxs.unsqueeze(-1)).sum()
                # Next timestep
                hcs = LM._hcs_from_cidx(hcs, charidxs)
            # Calculate gradients
            optimizer.zero_grad()
            nll.backward()
            # Debug output
            nll_sum += nll.detach().item()
            if i_batch % 50 == 0:
                bpc = (nll_sum / ((50 if i_batch > 0 else 1) * BATCHSIZE * BPTTLENGTH)) / math.log(2)
                nll_sum = 0
                greedy = ''.join(itertools.islice(LM.greedy_1_best(LM.hcs_from_context("")), 70))
                sample = ''.join(itertools.islice(LM.greedy_sample(LM.hcs_from_context("")), 70))
                print(f"{i_batch:5}/{len(batches)} {bpc:9.2f} -> {greedy} ~~ {sample}")
                print(f"{i_batch:5}/{len(batches)} {bpc:9.2f} -> {greedy} ~~ {sample}", file=logfile)
            # Apply gradients
            optimizer.step()

        torch.save(LM.state_dict(), STUB+".statedict.pt")
        print("Trained in", time.time() - start_time, "seconds.")
        print("Trained in", time.time() - start_time, "seconds.", file=logfile)
        start_time = time.time()
        _ = ''.join(itertools.islice(LanguageModel.greedy_sample(LM), 1000))
        print("Sampled 1000 chars in", time.time() - start_time, "seconds.")
        print("Sampled 1000 chars in", time.time() - start_time, "seconds.", file=logfile)

should_we_retrain_the_language_model = "no"

if should_we_retrain_the_language_model == "hell yeah":
    if not os.path.isfile("text8.zip"):
        import urllib
        urllib.request.urlretrieve("http://mattmahoney.net/dc/text8.zip", "text8.zip")
    if not os.path.isfile("text8"):
        import zipfile
        zip_ref = zipfile.ZipFile("text8.zip", 'r')
        zip_ref.extractall(".")
        zip_ref.close()

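    with open("text8", 'r') as f:
        # text8 uses plain spaces; this notebook (and the pretrained vocabulary) writes them as "␣"
        text8 = f.read().replace(" ", "␣")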
    
    train_lm("text8", text8, nlayers = 2, nhid = 512, embsize = 16, BATCHSIZE = 80, BPTTLENGTH = 150)

Just to illustrate how this class is supposed to work, let's run it on our example above:

In [None]:
# Load our pretrained models
LM = LanguageModel("abcdefghijklmnopqrstuvwxyz␣", layers=3, hidden_size=512, embedding_size=16).cpu()
LM.load_state_dict(torch.load("text8.lm.statedict.pt", map_location='cpu'))

In [None]:
# A sample from the LM (at local temperature T=0.5)
print(''.join(itertools.islice(LM.greedy_sample(temperature=0.5), 100)))

# Next-character distribution after "the␣defor" (unbatched, so a plain dict is returned)
LM.next_options(LM.hcs_from_context("the␣defor"))

Let's see what the model's trie looks like for our example context given above! We'll render it only to depth 4 and omit all expansions with conditional probability $p(x_i \mid \mathbf{x}_{<i}) \le 0.18$ to keep things readable:

In [None]:
the_defor = LM.hcs_from_context("the␣defor")
LM.render_trie_from(hcs=the_defor, depth=4, prob_gt=.18)

Notice that even with the threshold, this LSTM allows some words like "defores" (not actually English) and "deformer" that we didn't anticipate.  Now, given this trie, let's try to reproduce a table like the one given above that shows the *single most likely prediction* for each length:

In [None]:
def brute_force_get_best_string(lm, hcs, length):
    """
    Returns the most probable string of exactly `length` characters and its log-probability.
    You can solve this by recursion using `lm.next_options(..., logprobs=True)`.
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


In [None]:
for l in range(1, 3):
    s, lp = brute_force_get_best_string(LM, the_defor, l)
    print(f"{l:2}: {s:5} (p={math.exp(lp)})")

You can try lengths larger than 2, but you should quickly see that even a length-3 string takes too long to maximize by brute force: with 27 characters there are $27^\ell$ candidate strings of length $\ell$ (already 19,683 for $\ell = 3$). But you have already learned a technique to perform (approximate) maximization in these locally normalized models: beam search!
Let's try to apply that here!
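
Beam search keeps only the `beam_size` most probable prefixes after each step instead of all $27^\ell$. A minimal sketch on the warm-up's toy trie (redefined so the cell stands alone; `beam_search` is a hypothetical helper, not the `BeamDecisionAgent` used below, and it assumes every trie path is at least `length` characters long):

```python
import heapq

# Toy trie from the warm-up: {char: (conditional prob, subtrie)}.
TOY_TRIE = {
    "e": (0.4, {"s": (1.0, {"t": (0.7, {"a": (1.0, {})})})}),
    "m": (0.6, {"a": (0.5, {"t": (1.0, {"i": (1.0, {})})}),
                "e": (0.5, {"d": (1.0, {"␣": (1.0, {})})})}),
}

def beam_search(trie, length, beam_size):
    """Approximately find the most probable string of exactly `length` characters.

    Each hypothesis is (probability, string, subtrie still reachable from it).
    """
    beam = [(1.0, "", trie)]
    for _ in range(length):
        expansions = [(prob * p, string + char, child)
                      for prob, string, subtrie in beam
                      for char, (p, child) in subtrie.items()]
        # Prune: keep only the beam_size most probable partial strings.
        beam = heapq.nlargest(beam_size, expansions, key=lambda hyp: hyp[0])
    prob, string, _ = max(beam, key=lambda hyp: hyp[0])
    return string, prob

print(beam_search(TOY_TRIE, 2, beam_size=1))  # greedy: commits to "m" first
print(beam_search(TOY_TRIE, 2, beam_size=2))  # wide enough to keep "e" alive
```

With `beam_size=1` this reduces to greedy search; with a beam wide enough to hold every partial string it becomes exhaustive.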

First we need to recast our problem of finding the best string in the framework of `TaskSetting`s and `ProbabilityModel`s, specifically, the `StatefulTaskSetting` and `IncrementalScoringModel`, as we did in the last homework:

In [None]:
class PredictStringTask(StatefulTaskSetting):
    """
    The task predicts strings of a certain length.
    The internal "state" will only be the length of the string that is
    still to be generated (because this is sufficient for the structural zeros).
    """
    def __init__(self, vocab, length):
        super().__init__()
        self.vocab = tuple(vocab)
        self.length = length
    
    def initial_taskstate(self, *, xx):
        return self.length
    
    def next_taskstate(self, *, xx, a, taskstate):
        return taskstate - 1
    
    def iterate_y(self, *, xx, oo=None, yy_prefix):
        assert oo is None # let's not deal with that here
        if yy_prefix > 0:
            yield from self.vocab
        else:
            yield None


In [None]:
# It yields all sequences, as expected:
list(PredictStringTask("XY", 2).iterate_aa(xx=None))

In [None]:
class LMScorer(IncrementalScoringModel, torch.nn.Module):

    def __init__(self, task, lm):
        # Always initialize the PyTorch module first, so the registration hooks work!
        torch.nn.Module.__init__(self)
        super().__init__(task)
        self.lm = lm

    def initialize_params(self):
        """
        Since all the magic happens in `self.lm`, the `LanguageModel`, nothing here.
        """
        pass

    def initial_modelstate(self, *, xx):
        """
        Initialize the LM hidden states with `xx`, our preceding context.
        We will also store the current predictions in the model state to share
        them across different to-be-probed actions.
        """
        hcs = self.lm.hcs_from_context(xx)
        preds = self.lm.next_options(hcs, logprobs=True)
        return (hcs, preds)
    
    def score_a_s(self, *, xx, a, taskstate, modelstate):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


Let's see what predictions we would make for lengths 1, 2, 3, ..., 10 (and whether they're consistent with the predictions we made earlier):

In [None]:
def best_string(self, *, prefix, length, beam_size=5):
    scorer = LMScorer(PredictStringTask(self.idx2char, length), self)
    agent = BeamDecisionAgent(scorer, beam_size=beam_size)
    return ''.join(agent.decision(xx=prefix))

# Patch it into the language model
LanguageModel.best_string = best_string

print("Greedy search on our old example")
for l in range(1, 11):
    s = LM.best_string(prefix="the␣defor", length=l, beam_size=1)
    print(f"{l:2}: {s:10}")
print("Greedy search on a new example")
for l in range(1, 11):
    s = LM.best_string(prefix="this␣is", length=l, beam_size=1)
    print(f"{l:2}: {s:10}")
print("Beam search (beam size 10) on the same example")
for l in range(1, 11):
    s = LM.best_string(prefix="this␣is", length=l, beam_size=10)
    print(f"{l:2}: {s:10}")

## The "real" task (which we won't be solving)

Now let's try to formalize the overarching task.  Suppose our user is partway through entering the sentence $\mathbf{w} = $`sequence␣modeling␣is␣the␣best`.   A dialogue might go like this:

observed state | a few examples of possible actions | agent chooses | environment responds | reward (keystrokes saved)
-|-|-|-|-
`sequence␣modeling␣is␣` | `a` , `th` , `the` , `the␣` , `the␣s` , `the␣co` , `the␣mos` , `the␣most` | `the␣` | accept | $4-1 = 3$
`sequence␣modeling␣is␣the␣` | `s` , `co` , `fir` , `most` | `co` | `b` | $1-1 = 0$
`sequence␣modeling␣is␣the␣b`  | `e` , `es` , `est` | `est`| accept | $3-1 = 2$
`sequence␣modeling␣is␣the␣best` | $\Rightarrow$ This is a final state, so we can take no more actions and get no more reward. | &nbsp; | &nbsp; | $0$

In principle, an agent for this task could suggest any string to the user; the table above shows just a few possible actions at each time step. Our total score across all time steps was $3 + 0 + 2 = 5$. We of course want this score to be as high as possible.
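
The bookkeeping in this table can be replayed mechanically. A sketch under simplifying assumptions (the user's target string is known, the suggestions are fixed in advance, and the end-of-string rule is not modeled); `replay_dialogue` is a hypothetical helper, not part of the homework code:

```python
def replay_dialogue(target, typed, suggestions):
    """Replay fixed suggestions against a user who wants to type `target`.

    `typed` is the number of characters already entered. An accepted suggestion
    of length a earns a - 1 (a characters entered for one accepting keystroke);
    a rejected one earns 0, and the user types the next character of `target`.
    """
    index, rewards = typed, []
    for suggestion in suggestions:
        if target.startswith(suggestion, index):   # accepted
            index += len(suggestion)
            rewards.append(len(suggestion) - 1)
        else:                                      # rejected: user types one character
            index += 1
            rewards.append(0)
    return rewards, target[:index]

target = "sequence␣modeling␣is␣the␣best"
rewards, prefix = replay_dialogue(target, len("sequence␣modeling␣is␣"), ["the␣", "co", "est"])
print(rewards, "return =", sum(rewards), "final prefix:", prefix)
```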
 
Intuitively, what would a good approach to this task look like? First of all, we definitely don't want to suggest *unlikely* strings - a language model will come in handy here. Secondly, we want to suggest a lot of characters at a time, to maximize the keystrokes saved. But we need to trade this off against the chance that the user actually accepts our suggestion, so we probably don't want to suggest strings that are *too* long. 

### Formalizing the task: Markov Decision Processes

In order to map the abstract game described above into a more formal setting, we define a couple of terms.

The *state space* is the set of all possible states that the agent could be in. In this case, the state space is the set of all strings over our vocabulary, which correspond to possible values of the prefix that has already been entered. There is also an *initial state distribution*, which in our case places probability 1 on the empty prefix `""`.

The *action space* is the set of all actions that the agent can take in any state. In this task, we can suggest any string to the user, so our action space is also the set of all strings over our vocabulary. (It is unusual for the state space and the action space to coincide; in general they are different.)

The *transition function* is a function that takes in a state and an action, and outputs a sample from a distribution over states. In our case, the transition function is essentially the user of the predictive keyboard: after we take an action, suggesting some autocompletion, the user either accepts that suggestion (transitioning us to the state consisting of our old state concatenated with our suggested autocompletion), or types another character instead (transitioning us to the state consisting of our old state concatenated with one new character). Or, the user presses "send" on their message, transitioning us to the *terminal state*: the state in which no further actions are possible.

An *episode* is a sequence state → action → state → action → … → terminal state.

The *reward* is the score given in response to taking a particular action at a particular timestep. The *return* is the sum of all rewards in an episode. These two terms are similar, so take note of the difference! In our case, the reward is the number of characters saved by the user in response to a particular suggestion, and the return is the total number of characters saved over the whole episode.

Together, the state space, initial state distribution, action space, reward function, and transition function define an *environment*. An *agent* exists in an environment and interacts with it. The agent can have many components, but one important component that every agent has is a *policy*.

The *policy* is the function which maps from states to actions. Our overall task-independent goal can be stated as follows: create an agent whose policy has high expected return. In this case, we are looking to create an agent whose policy for suggesting autocompletions saves the user the most total keystrokes.

## Simplifying things for the homework

Unfortunately, the problem setting as described above is a bit too tough for a homework assignment. Therefore, we are going to simplify the problem in several ways (some of which are unrealistic):

We assume that the user is generating their strings according to some language model **p\***. The agent may or may not know **p\***.

For any given episode, we assume oracle knowledge of the total length of the complete string that the user wants to enter. Additionally, we "summarize" any given state of the environment by a single dense vector $\mathbf{h}$, namely the hidden state of a pre-trained LSTM language model that has read the already-typed prefix associated with that state.

Thus, each state is a 2-tuple $s = (\mathbf{h},k) =$ (prefix_vector, nchars_left), and the state space is the set of all such pairs.  Our policy will choose its action based only on this 2-tuple.  We will also predict the environment's response based only on that 2-tuple.  This is a kind of conditional independence assumption.

(Implementation note: in some places in the code, we prefer to view the state as the 2-tuple (prefix_string, nchars_left), instead. Therefore, we always actually pass around the 3-tuple (prefix_string, prefix_vector, nchars_left), and only ever use either one or the other of the first two.)

(Including "nchars_left" is completely unrealistic and shouldn't really be necessary.  We included it in the state only because our pre-trained LM was not trained on a collection of actual text messages ending in EOS, but rather on a single very long string consisting of many concatenated documents.  So it unfortunately couldn't ever learn how to predict EOS: $\mathbf{h}$ does not include *any* information about how close to the end of the user's string we are. Therefore, we cheated and augmented it with *perfect* information about this.)

We also significantly reduce our action space by assuming that we have an oracle language model that we can use, in any given state, to estimate the most probable choice of the next $a$ characters, using beam search.  We now restrict our action set to $\{ a : 1 \leq a \leq 10 \}$.  In other words, there are 10 possible actions -- we must suggest the estimated most probable string of length 1, or of length 2, or ... you get the point.

If the user is close to the end of the string, and the agent suggests a longer string than the remaining characters that the user wants to type, the environment will reject the suggestion. However, the environment is lenient regarding final spaces: if the suggestion is longer than the user's string, but all of the extra characters are spaces (which happens often), the user will still accept the suggestion. 
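
The lenient end-of-string rule can be restated as a small predicate. This is our own sketch, assuming spaces are written as `␣` as elsewhere in this notebook; `user_accepts` is a hypothetical name, and `TypistState.execute_action` below implements the same rule in its own way:

```python
def user_accepts(suggestion, remaining):
    """Would a user who still wants to type `remaining` accept `suggestion`?

    Characters that run past the end of `remaining` are tolerated only if they
    are all spaces ("␣"); otherwise the suggestion must be a prefix of `remaining`.
    """
    head, overflow = suggestion[:len(remaining)], suggestion[len(remaining):]
    if overflow.strip("␣"):        # some non-space character runs past the end
        return False
    return remaining.startswith(head)

remaining = "est"
for suggestion in ["es", "est", "est␣␣", "est␣a", "esta", "ex"]:
    print(f"{suggestion!r:9} -> {user_accepts(suggestion, remaining)}")
```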

To start off, the reward function will be as described above: the number of keystrokes saved by suggesting some autocompletion. An accepted suggestion of length $a$ enters $a$ characters at the cost of one keystroke (the acceptance), so its reward is $a - 1$. Note that this means that the reward scales with the length of the suggested string, but only if the suggestion is accepted: if the suggestion is rejected, the reward is 0 regardless of what character the user types. (This reward function is pretty intuitive and easy to work with. Later on in the homework, we will also consider the case where the reward function is less "nice".)


## The actual environment (the typist)

Let's try to build an implementation for this scenario now. We will start with:
1. A definition of the environment, or more specifically, the environment in a specific state. It will, given this state, look at the action that the agent chose, and return the reward that the agent obtains for this action, as well as sample the output of the transition function and return the next state. (Also, it will update its own internal state to the next state.)
2. An `RLAgent`, that represents our agent and its policy: given a state, it decides on what action to take.

In [None]:
class RLAgent(object):
    """
    Note that this RLAgent is similar to the DecisionAgent.
    """
    
    def decision(self, *, state):
        """
        Makes a decision, based on the state that it sees.
        """
        raise NotImplementedError()
        
    def receive_response(self, *, state, action, reward, next_state):
        """
        Updates the agent's internal state using the response it received from
        the environment.
        """
        raise NotImplementedError()

In [None]:
class EnvironmentState(object):
    """
    Representing a specific state our environment is in for the current episode.
    Note that we explicitly make this object stateful to highlight that we
    cannot just take any action free of consequences.
    """
        
    def execute_action(self, *, action):
        """
        Returns the reward for an action and the new state that the agent
        has entered. Also, updates the current state of the environment.
        Need not be deterministic.
        """
        raise NotImplementedError()

    def evaluate_agent(self, *, agent=None, agentclass=None):
        """
        Convenience method, executing an entire episode by running the agent.
        If passed an `agent`, it will just use it, if passed an `agentclass`, it
        will try to instantiate that class in some way and then use it.
        Returns the total reward (a.k.a. the return).
        """
        raise NotImplementedError()

For implementational convenience and perhaps even realism, our environment is an idealized user who starts out by randomly sampling a *particular* string $\mathbf{w}$ that they want to type.  This string does not change as the agent acts.  The subsequent accept/reject decisions by the environment do depend on the agent's actions, but without any further randomness - they are fully determined by the agent's actions together with the previously sampled string $\mathbf{w}$.

## The agent's (imperfect) model of the environment

Since $\mathbf{w}$ is a hidden (latent) part of the environment's true state $\bar{s}$, the true setting is a POMDP setting.  But we will, for simplicity, construct our agent anyway as if the setting were an MDP with simpler states $s$ of the form $s = (\mathbf{h},k) =$ (prefix_vector, nchars_left) that the agent actually observes.

(What are the consequences of this simplification?  The agent doesn't know the true POMDP state $\bar{s}$.  So it doesn't actually have enough information to determine the true values $\bar{Q}(\bar{s},a)$ of the different actions, even if it had the ability to do exhaustive search over the possible future rollouts.  Instead the agent is under the mistaken impression that $s$ is all it needs to know to choose its action.  As a result, its policy can only depend on the prefix typed so far, and not on the history of actions and observations that constructed that prefix.  *Example:* Suppose the user rejects the suggestion `that` by typing `t`.  The user would now *certainly* reject `hat` as the next suggestion, but the MDP agent will fail to realize that!  In contrast, a POMDP agent would be able to update its belief about the unknown state $\bar{s}$ to reflect the new evidence that $\mathbf{w}$ (whatever it may be!) clearly wasn't supposed to continue with `that`, and hence it would not propose `hat` after the `t`.)
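
The difference can be made concrete with a toy belief over a handful of candidate target strings. This is a minimal sketch, assuming a hypothetical discrete prior (`prior`, `condition_on_rejection`, and `accept_prob` are our own names). Because the user's responses are deterministic given $\mathbf{w}$, Bayes' rule reduces to discarding the inconsistent candidates and renormalizing.

```python
def condition_on_rejection(belief, prefix, rejected, typed_char):
    """Posterior over target strings w after the user, at `prefix`, rejects
    `rejected` and types `typed_char` instead (deterministic user, so Bayes'
    rule means: drop inconsistent candidates, then renormalize)."""
    consistent = {w: p for w, p in belief.items()
                  if w.startswith(prefix + typed_char)
                  and not w.startswith(prefix + rejected)}
    total = sum(consistent.values())
    if total == 0:
        raise ValueError("observation has zero probability under the belief")
    return {w: p / total for w, p in consistent.items()}

def accept_prob(belief, prefix, suggestion):
    """P(w continues `prefix` with `suggestion`), assuming `prefix` has nonzero mass."""
    mass = sum(p for w, p in belief.items() if w.startswith(prefix))
    return sum(p for w, p in belief.items() if w.startswith(prefix + suggestion)) / mass

prior = {"i␣saw␣that": 0.5, "i␣saw␣them": 0.3, "i␣saw␣a␣dog": 0.2}

# MDP-style agent: conditions only on the typed prefix "i␣saw␣t".
mdp_p = accept_prob(prior, "i␣saw␣t", "hat")
# POMDP-style agent: also remembers that "that" was just rejected.
posterior = condition_on_rejection(prior, "i␣saw␣", "that", "t")
pomdp_p = accept_prob(posterior, "i␣saw␣t", "hat")
print(f"MDP agent: P(accept 'hat') = {mdp_p:.3f}; POMDP agent: {pomdp_p:.3f}")
```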

For now, let's assume that our agent does not have to do any learning because it is given a pretty good model of the environment, in other words, a model $\hat{p}(s' \mid s,a)$ and $\hat{r}(s, a, s')$.  Here $s'$ represents the predicted next state of the MDP,  and $\hat{r}$ is the associated reward if that is the next state.   The action $a$ is the number of characters in the substring that the agent proposes; it will always correspond to the beam-search approximation of the maximum-likelihood string of length $a$ given context $s$, where the likelihood is computed under **p\***.  

Specifically, the agent's environment model assumes that the probability that the environment will accept the suggested string of length $a$, and append it to the current prefix, equals the probability according to a language model that the next $a$ characters following the current prefix would in fact be the suggested string.  This environment model might be bad if the agent's language model is not the actual distribution **p\*** from which the environment sampled $\mathbf{w}$.  But for now, let's allow the agent to use **p\*** as its language model.

In addition, let's say that the agent knows the true reward structure, so $\hat{r}$ gives the same reward for acceptance as the real environment would.
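
Under this model, proposing $a$ characters whose suggested string has probability $p_a$ under the agent's LM yields reward $a-1$ with probability $p_a$ and $0$ otherwise, so its expected *immediate* reward is $p_a (a-1)$. A minimal sketch of a myopic agent that maximizes only this quantity and ignores all future rewards. The inputs are hypothetical: the probabilities of the most probable strings of lengths 1 to 4 in the warm-up's toy trie. The function names are ours.

```python
import math

def expected_immediate_reward(p_accept, length):
    """E[r] for proposing `length` characters: (length - 1) if accepted, 0 if rejected."""
    return p_accept * (length - 1) + (1 - p_accept) * 0

def myopic_action(p_accept_by_length):
    """Pick the length with the highest one-step expected reward (ignores the future)."""
    return max(p_accept_by_length,
               key=lambda a: expected_immediate_reward(p_accept_by_length[a], a))

# Hypothetical inputs: best-string probability for each length in the toy trie.
toy = {1: 0.6, 2: 0.4, 3: 0.3, 4: 0.3}
for a, p in toy.items():
    print(a, round(expected_immediate_reward(p, a), 2))
print("myopic choice:", myopic_action(toy))
```

Because the toy trie stops at depth 4, the myopic choice here is simply the longest option; the sketch also ignores what happens after a rejection, which is where planning ahead comes in.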

In [None]:
class TypistState(EnvironmentState):
    def __init__(self, *, lm, string=None, start_index=0, total_length=10, debug=False): 
        self.lm = lm
        self.debug = debug        
        self.reset(string=string, start_index=start_index, total_length=total_length)

    def reset(self, *, string=None, start_index=0, total_length=20):
        """
        Our state will be the already-typed part of the string.
        """
        if string is None:
            self.string = ''.join(itertools.islice(self.lm.greedy_sample(temperature=0.2), total_length))
        else:
            self.string = string
        self.current_index = start_index
        self.all_hcs = self.lm.hcs_from_context(self.string, get_all_hcs=True)
        return self.current_state

    @property
    def current_prefix(self): return self.string[:self.current_index]
    @property
    def current_hcs(self): return self.all_hcs[self.current_index]
    @property
    def current_nchars_left(self): return len(self.string) - self.current_index
    @property
    def current_state(self):
        return (self.current_prefix, self.current_hcs, self.current_nchars_left)

    def get_suggestion(self, *, action):
        return self.lm.best_string(prefix=self.current_prefix, length=action)
    
    def successful_prediction_reward(self, action): return action - 1
    def failed_prediction_reward(self, action): return 0
    
    def execute_action(self, *, action):
        """
        Given the action (a proposed string) check whether it is "correct"
        and reward accordingly.
        """
        assert 0 < action <= 10
        # What are we looking for?
        goldanswer = self.string[self.current_index : self.current_index + action]
        # What did we suggest?
        suggestion = self.get_suggestion(action=action)
        if action > self.current_nchars_left:
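            # The proposal runs past the end of the string: drop any spaces beyond
            # the end (they are forgiven) and shorten the action accordingly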
            action -= len([c for c in suggestion[self.current_nchars_left:] if c == '␣'])
            suggestion = suggestion[:self.current_nchars_left] + ''.join([c for c in suggestion[self.current_nchars_left:] if c != '␣'])
        # Check what happens
        if self.debug:
            print("If we have \"" + self.string[:self.current_index]
                  + "\" and still need \"" + self.string[self.current_index:]
                  + "\", the proposition \"" + suggestion, end='" ')
        if suggestion == goldanswer:
            # Advance the task state that far
            self.current_index += action
            reward = self.successful_prediction_reward(action)
            if self.debug:
                print("is correct!", end=' ')
        else:
            # Only advance by one
            self.current_index += 1
            reward = self.failed_prediction_reward(action)
            if self.debug:
                print("is incorrect (\"" + goldanswer + "\" would have been correct)!", end=' ')
        if self.debug:
            print("We get reward", reward, "and new state \"" + self.current_prefix + "\".")
        return reward, self.current_state
    
    def evaluate_agent(self, *, agent=None, agentclass=None, return_agent=False, reset=False, **kwargs):
        assert agent is None or agentclass is None
        if agent is None:
            agent = agentclass(
                **kwargs  # this will be the way to pass the LM
            )
            
        if reset: self.reset()
        total_reward = 0
        if self.debug:
            print(f'Starting with state "{self.string[:self.current_index]}>{self.string[self.current_index:]}":')

        while self.current_index < len(self.string):
            # Get action
            state = self.current_state
            action = agent.decision(state=state)
            if self.debug:
                print(f'Agent chooses to predict {action}!')
            # Use reward and response
            reward, next_state = self.execute_action(action=action)
            agent.receive_response(state=state, action=action, reward=reward, next_state=next_state)  # We will define this below.
            total_reward += reward
            if self.debug:
                print(f'User gave reward {reward} and response "{next_state[0]}"')
        if self.debug:
            print("Received", total_reward, "total reward!")
        return total_reward if not return_agent else (total_reward, agent)

Let's see what our possible actions are in the initial state of the run that we saw above, and in some possible subsequent states.  For each action, we'll also see what the environment *would* do if the agent took that action.  Of course, the agent will only get to take *one* of the actions.

In [None]:
for nchars in range(21, 25):
    for action in range(1, 5):
        # Set up an environment just to where we can test the action...
        env = TypistState(lm=LM, string="sequence␣modeling␣is␣the␣best", start_index=nchars, debug=True)
        env.execute_action(action=action)

## A first agent: greedily choosing actions

A first very simple agent is going to use its model to greedily choose the best action. Since our reward function gives meaningful rewards at every step, it's a sensible heuristic to always pick the action that gives the best expected immediate reward.  (This is not optimal, however, because it doesn't consider how the action affects the next state $s'$: getting to a good state $s'$ could set us up for big future rewards.  We'll see the impact of that later.)

Equipped with this model, the agent can compute the *expected immediate reward* of each action $a$. In our case, that is the reward the agent receives if the user accepts the length-$a$ proposal, times the model probability that the user accepts it, plus a reward of 0 times the model probability that the user instead rejects it and types the next character.
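Written out, with $\hat{p}_{\text{acc}}(s, a)$ as a shorthand (introduced only here) for the model probability that the length-$a$ proposal is accepted in state $s$:

$$
\mathbb{E}[r \mid s, a] = \hat{p}_{\text{acc}}(s, a) \cdot (a - 1) + \bigl(1 - \hat{p}_{\text{acc}}(s, a)\bigr) \cdot 0 = \hat{p}_{\text{acc}}(s, a) \cdot (a - 1)
$$

For instance, if a four-character proposal had a (hypothetical) acceptance probability of $0.9$, its expected immediate reward would be $0.9 \cdot 3 = 2.7$, while a one-character proposal always has expected immediate reward $0$.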

In [None]:
def expected_immediate_reward(*, prefix, nchars_left, action, agent_lm, success_reward_fn, fail_reward_fn):
    """
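    Expected immediate reward of proposing the 1-best string of length `action`.
    The prefix is the string that we observed so far; `agent_lm` is the
    (potentially inadequate) LM of the agent.
    """
    # No characters left: nothing to propose, so no reward
    if nchars_left == 0: return torch.tensor(0.0)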
    # What would the probability of the string we would give be, i.e.,
    # what is the probability that it is correct?
    scorer = LMScorer(PredictStringTask(agent_lm.idx2char, action), agent_lm)
    aa = agent_lm.best_string(prefix=prefix, length=action)
    prob = scorer.score_aa(xx=prefix, aa=aa).exp()
    # What would we get if it was indeed correct?
    # We assume knowledge of the reward function here of course.
    positive_reward = success_reward_fn(action)
    negative_reward = fail_reward_fn(action)
    return prob * positive_reward + (1 - prob) * negative_reward

Let's actually compute those expected immediate rewards. What is the immediate reward we expect from taking any action when we are in some state?

In [None]:
prefix = "sequence␣modeling␣is␣"
for action in range(1, 6):
    er = expected_immediate_reward(prefix=prefix, nchars_left=10, action=action, agent_lm=LM,
                                   success_reward_fn=lambda action: action-1,
                                   fail_reward_fn=lambda action: 0)
    print(action, f"{er.item():.4f}")

Unsurprisingly, predicting four characters is best: once we are confident that the next character is "t", also predicting "h", "e", and "␣" adds little risk, but the reward is much higher.
Just to make sure you understand these quantities: 

- What is the minimum the reward could ever be (for any action $a \in \mathbb{N}$)? $\color{red}{\text{FILL IN}}$
- How many actions will achieve that minimum (given the environment we use here)? $\color{red}{\text{FILL IN}}$
- How far do we have to look to find the action with the highest reward? $\color{red}{\text{FILL IN}}$
- Assume $a=4$ indeed achieves the maximum. Is this enough to tell us that we should definitely pick $a=4$? Why or why not? $\color{red}{\text{FILL IN}}$

So let's try to build a greedy agent to get us good actions!

In [None]:
class TextProposalAgent(RLAgent):
    """
    These things will be shared by all the agents we will build in the sequel.
    """
    def __init__(self, *, lm=None):
        self.lm = lm
    def decision(self, *, state):
        pass    
    def receive_response(self, *, state, action, reward, next_state):
        pass

In [None]:
class GreedyExpectedImmediateRewardAgent(TextProposalAgent):
    def success_reward_fn(self, action): return action-1
    def fail_reward_fn(self, action): return 0
  
    def decision(self, *, state):
        """
        This function defines a policy: make a decision based on the triple `state`.
        """
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END

# This should say: predict 4 characters, namely "the␣"
env = TypistState(lm=LM, string="sequence␣modeling␣is␣the␣best", start_index=21)
assert GreedyExpectedImmediateRewardAgent(lm=LM).decision(state=env.current_state) == 4 # "the␣"

Now, given our Typist and our Agent, it's time to play this game!

In [None]:
# We should get a total reward of 5 here
string = "sequence␣modeling␣is␣the␣best"
start_index = 21
# Verbose
assert TypistState(lm=LM, string=string, start_index=start_index, debug=True) \
        .evaluate_agent(agent=GreedyExpectedImmediateRewardAgent(lm=LM)) == 5
# Or, shorter, using some magic and only getting the result without debug info
assert TypistState(lm=LM, string=string, start_index=start_index) \
        .evaluate_agent(agentclass=GreedyExpectedImmediateRewardAgent, lm=LM) == 5

Is that good or bad? Let's compare against a baseline that chooses randomly among the possible actions (all of which are rather good since they are high-probability under the LM):

In [None]:
# A random length agent
class RandomLengthAgent(TextProposalAgent):
    def decision(self, *, state):
        action = random.randint(1, 10)
        return action

In [None]:
# Run it a few times to see how well it does on average
rs = [
    TypistState(lm=LM, string=string, start_index=start_index).evaluate_agent(agentclass=RandomLengthAgent, lm=LM)
    for _ in range(10)
]
print("Rewards:", rs)
print("Average reward:", np.average(rs))

So our greedy agent is definitely a lot better. But let's not just evaluate this on one cherry-picked example:

In [None]:
def compare_agents(agent1class, agent2class, true_lm=LM, agent_lm=LM, n_sentences=5, total_length=20, start_index=12, verbose=True):
    """
    Sample some sentences from true_lm and compare the performance of
    the passed two agents on these sentences.
    """
    random.seed(0)
    sum_1, sum_2 = 0, 0
    for _ in range(n_sentences):
        # Draw a sentence from the LM (though at lower local temperature, so the agent model isn't 100% perfect)
        sample = ''.join(itertools.islice(true_lm.greedy_sample(temperature=0.2), total_length))
        # Test both agents
        agent1_reward = TypistState(lm=true_lm, string=sample, start_index=start_index).evaluate_agent(agentclass=agent1class, lm=agent_lm)
        agent2_reward = TypistState(lm=true_lm, string=sample, start_index=start_index).evaluate_agent(agentclass=agent2class, lm=agent_lm)
        # What happened!
        sum_1 += agent1_reward
        sum_2 += agent2_reward
        if verbose:
            print(f"On {sample[:start_index]}>{sample[start_index:]}, agent 1 got {agent1_reward}, agent 2 got {agent2_reward}.")
    return (sum_1, sum_2)

In [None]:
%time compare_agents(RandomLengthAgent, GreedyExpectedImmediateRewardAgent)

As you will have noticed, this is far too slow to use in the remainder of the assignment.
That's why we will break our nice abstractions a little and write a single (fast) function for our agent.

We will inline the beam search but, more importantly, we will *fuse* all the different searches: beam search never looks ahead, so the search for the best string of length $l$ is exactly the first $l$ steps of the search for any longer length. A single search that records its best hypothesis at every length therefore replaces them all!
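Stripped of the LSTM, the fusion idea fits in a few lines. This is a sketch on a made-up bigram table (`TOY_BIGRAM` and `fused_best_strings` are hypothetical names): one beam search runs to the maximum length and appends its current best hypothesis after every step, so entry $l$ of the result is what a separate search for length $l$ would have returned.

```python
import math

# Hypothetical toy bigram LM: p(next char | previous char), standing in for the LSTM.
TOY_BIGRAM = {
    "␣": {"t": 0.6, "a": 0.4},
    "t": {"h": 0.9, "o": 0.1},
    "h": {"e": 0.8, "a": 0.2},
    "e": {"␣": 0.7, "n": 0.3},
    "a": {"␣": 1.0},
    "o": {"␣": 1.0},
    "n": {"␣": 1.0},
}

def fused_best_strings(prefix, max_length, beam_size=2, bigram=TOY_BIGRAM):
    """One beam search; z[l] = (logprob, string) is the best hypothesis of length l."""
    beam = [(0.0, "")]
    best = [beam[0]]
    for _ in range(max_length):
        candidates = [
            (score + math.log(p), string + char)
            for score, string in beam
            for char, p in bigram[(prefix + string)[-1]].items()
        ]
        candidates.sort(key=lambda c: -c[0])
        beam = candidates[:beam_size]
        best.append(beam[0])  # a length-l search would stop exactly here
    return best

for logprob, string in fused_best_strings("is␣", 4)[1:]:
    print(f"{string!r}: p = {math.exp(logprob):.4f}")
```

The cost is one search of length `max_length` instead of `max_length` separate searches of lengths $1, \ldots,$ `max_length`.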

In [None]:
def best_strings_from_hcs(self, *, hcs, max_length, beam_size=5):
    """
    Runs a single beam search from the LM state `hcs` and returns a list z
    such that z[l] = (logprob, string, hcs) is the best hypothesis of length l
    (z[0] is the empty string).
    """
    beam_size = min(beam_size, len(self.idx2char))
    queue = [(torch.tensor(0.0), "", hcs)]
    returns = [queue[0]]
    while len(queue[0][1]) < max_length:
        next_queue = []
        for pscore, ptaskstate, pmodelstate in queue:
            probs = (pmodelstate[-1][0] @ self.embedding.weight.t()).log_softmax(dim=-1)
            for nscore, idx in zip(*probs.squeeze(0).topk(beam_size)):
                nmodelstate = self._hcs_from_cidx(pmodelstate, torch.tensor([idx]))
                next_queue.append((nscore + pscore, ptaskstate + self.idx2char[idx], nmodelstate))
        if len(next_queue) == 0:
            break
        next_queue.sort(key=lambda x: -float(x[0]))
        queue = next_queue[:beam_size]
        returns.append(queue[0])
    return returns

# Patch it, too, into the language model
LanguageModel.best_strings_from_hcs = best_strings_from_hcs

In [None]:
# Try it out:

print("Our old function")
start_time = time.time()
for _ in range(1):
    for a in range(1, 8+1):
        print(LM.best_string(prefix="sequence␣modeling␣is␣", length=a))
print(time.time() - start_time)

print("\nOur new function")
start_time = time.time()
hcs = LM.hcs_from_context("sequence␣modeling␣is␣")
print('\n'.join([x[1] for x in LM.best_strings_from_hcs(hcs=hcs, max_length=8)][1:]))
print(time.time() - start_time)

print("\nOur new function, much farther into the string")
start_time = time.time()
hcs = LM.hcs_from_context("sequence␣modeling␣is␣")
print('\n'.join([x[1] for x in LM.best_strings_from_hcs(hcs=hcs, max_length=50)][-10:]))
print(time.time() - start_time)

Much better! Let's use it to implement faster agents!


In [None]:
class FastGreedyExpectedImmediateRewardAgent(TextProposalAgent):
    def success_reward_fn(self, action): return action-1
    def fail_reward_fn(self, action): return 0  

    def decision(self, *, state):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


In [None]:
# Compare on two sentences
for sample in ["modeling␣is␣the␣best", "kare␣to␣from␣first␣g"]:
    print("Predict", sample[:12], ">", sample[12:])
    
    def get_reward(agentclass):
        random.seed(0)
        user = TypistState(lm=LM, string=sample, start_index=12)
        return user.evaluate_agent(agentclass=agentclass, lm=LM)
    # Test the greedy speedup
    %time      random_reward = get_reward(                     RandomLengthAgent)
    %time      greedy_reward = get_reward(    GreedyExpectedImmediateRewardAgent)
    %time fast_greedy_reward = get_reward(FastGreedyExpectedImmediateRewardAgent)
    assert greedy_reward == fast_greedy_reward, (greedy_reward, fast_greedy_reward)
    

By breaking our beautiful abstraction and making some strong assumptions, we have made our solution about 10x faster!
Now, let's run a little comparative study:

In [None]:
%time compare_agents(RandomLengthAgent, FastGreedyExpectedImmediateRewardAgent)

This should be at least twice as fast -- nice! Let's try a few more:

In [None]:
%time sum_random, sum_greedy = compare_agents(RandomLengthAgent, FastGreedyExpectedImmediateRewardAgent, n_sentences=20)
print("Totals! Random:", sum_random, "-- Greedy:", sum_greedy)

Interestingly, the random agent sometimes seems to outperform the greedy agent. To be fair, of course, the random agent isn't quite that random: it still extracts the best string of a given length from the LM. This is mostly an artifact of the limited scope of the problem that we are attempting to tackle in this assignment. In the "real" version of this problem, our action space would contain all strings, not just the ten 1-best strings of lengths 1 to 10; a random agent in that action space would be remarkably terrible. As it is, the random agent can occasionally get lucky and pick a better action sequence than our more intelligent greedy agent.

## Simplifying things: predict from a three-letter alphabet

Because, even with our speedups, this is still quite slow in a Python notebook, we will simplify the task a little: our strings are now over an alphabet of three characters: "x", "o", and our old readability-improving friend "␣". The language model will also be much smaller and thus faster.

What data are we training on? We follow established NLP practice and pretend Jason's NLP class homework 1 toy grammars are somehow meaningful. (Established NLP practice? Please, catch the sarcasm and read [Yoav Goldberg's great takedown](https://medium.com/@yoav.goldberg/an-adversarial-review-of-adversarial-generation-of-natural-language-409ac3378bd7) of someone outside JHU who tried to write a real paper using those grammars.) We generate some random sentences from one of those grammars, then replace all characters in terminal symbols with "x" or "o" and replace spaces with "␣". This will be our training dataset for the language model.

```
$ ./randsent holygrail.gr 350000 | tr '\nabcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ. ' ' xoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxoxox ' | head -c 10000000 | sed 's/ /␣/g' | gzip > xo_corpus.txt.gz

$ zcat xo_corpus.txt.gz | sed 's/␣/\n/g' | sort -nr | uniq -c | sort -nr
 382815 x
 102040 xoxox
  92066 oxoxx
  72006 ox
  71656 ooxx
  69657 oox
  69640 xoxooxo
  69508 xox
  69495 ooxo
  69394 xxxo
  59820 xxooxo
  59718 oxxx
  52349 xxoxox
  52225 ooxoxx
  52175 oxx
  52111 xx
  52003 xxooxxx
  37677 xxo
  [...]
```
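For readers less familiar with `tr`, here is a Python sketch of the same character mapping. It assumes (as the two `tr` sets suggest) that letters alternate between "x" and "o" in alphabetical order, lowercase then uppercase, and that "." maps to "x"; it ignores the byte truncation done by `head -c`. `XO_TABLE` and `to_xo` are illustrative names.

```python
import string

# Letters in the same order as in the tr command: a-z, then A-Z.
LETTERS = string.ascii_lowercase + string.ascii_uppercase
XO_TABLE = str.maketrans({
    **{c: "xo"[i % 2] for i, c in enumerate(LETTERS)},  # a->x, b->o, c->x, ...
    ".": "x",   # sentence-final periods become a lone "x"
    "\n": "␣",  # tr maps newlines to spaces ...
    " ": "␣",   # ... and sed then maps spaces to "␣"
})

def to_xo(text):
    return text.translate(XO_TABLE)

print(to_xo("the knight .\n"))  # oox␣xoxxoo␣x␣
```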

In [None]:
retrain_small_model = False

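if retrain_small_model:
    if not os.path.isfile("xo_corpus.txt.gz"):
        urllib.request.urlretrieve("https://sjmielke.com/tmp/xo_corpus.txt.gz", "xo_corpus.txt.gz")

    with gzip.open("xo_corpus.txt.gz", 'rt') as f:
        xo_corpus = f.read()

    # NOTE: these hyperparameters differ from those of the pretrained model loaded in
    # the next cell (2 layers, 64 hidden units, 4-dimensional embeddings); make them
    # match if you want to load the retrained weights into that LanguageModel.
    train_lm("xo", xo_corpus, nlayers=3, nhid=512, embsize=32, BATCHSIZE=200, BPTTLENGTH=100)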

In [None]:
# Load the pretrained model
XOLM = LanguageModel("ox␣", layers=2, hidden_size=64, embedding_size=4).cpu()
XOLM.load_state_dict(torch.load("xo.lm.statedict.pt", map_location='cpu'))

In [None]:
sample = "".join(itertools.islice(XOLM.greedy_sample(), 100))
print(sample)

Looks like we got the structure down (see the single "x"s that used to be "."s).
Just to check that this was worth it, let's rerun it on our game:

In [None]:
%time sum_random, sum_greedy = compare_agents(\
    RandomLengthAgent,\
    FastGreedyExpectedImmediateRewardAgent,\
    true_lm=XOLM,\
    agent_lm=XOLM,\
    n_sentences=100,\
    verbose=False\
)
print("Totals! Random:", sum_random, "-- Greedy:", sum_greedy)

That speed will do, which is good, because we are about to make everything much slower and more complicated.

## From greedy actions to planning

You will have noticed that we stressed that this (definitely better-than-random) agent of ours is *greedy* -- and you should know that we generally consider greediness a bad thing. Specifically, a greedy agent can fail to take actions that have low immediate reward but "set up" the agent for future success by moving it to a good state.

To emphasize this, we will make a small change to the reward function of the task. Since we want to help the user enter the text in as few steps as possible (remember that the user has to press one key per step), let's define the reward to be $-1$ on every step (so it's really a penalty). The total reward of an episode is now the negative total number of keystrokes. This equals our old total reward (the number of keystrokes saved) minus a constant, namely $|\mathbf{w}|$, the number of characters the user has to enter. So the *optimal* policy should be the same. But the reward is now distributed differently across the time steps. With the new function, all actions have the same immediate reward, so they are all tied and the *greedy* policy has no way to choose!
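A quick simulation makes the constant offset concrete. In each step the text advances by $d$ characters ($d = a$ on acceptance, $d = 1$ otherwise); the old reward is $d - 1$ and the new one is $-1 = (d - 1) - d$, so over an episode the new total equals the old total minus the number of characters typed. The sketch uses random actions and a made-up acceptance probability; `simulate_episode` is a hypothetical helper that ignores the end-of-string space cropping done by `TypistState`.

```python
import random

def simulate_episode(n_chars, rng, accept_prob=0.5, max_action=10):
    """Random typing episode; returns (old_total_reward, new_total_reward)."""
    remaining, old_total, new_total = n_chars, 0, 0
    while remaining > 0:
        action = rng.randint(1, max_action)
        # Simplification: a proposal longer than the remaining text is always rejected.
        if action <= remaining and rng.random() < accept_prob:
            remaining -= action       # accepted: advance by `action` characters
            old_total += action - 1   # old reward: keystrokes saved
        else:
            remaining -= 1            # rejected: the user types one character (old reward 0)
        new_total += -1               # new reward: -1 per step, no matter what
    return old_total, new_total

rng = random.Random(0)
for n in [1, 5, 20, 50]:
    old, new = simulate_episode(n, rng)
    print(n, old, new, old - new)  # the last column equals n
```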


In [None]:
# update our environment, and our greedy agents
TypistState.successful_prediction_reward = lambda *args:-1
TypistState.failed_prediction_reward = lambda *args:-1

class PenaltyRewardFastGreedyExpectedImmediateRewardAgent(FastGreedyExpectedImmediateRewardAgent):
    def success_reward_fn(self, action): return -1
    def fail_reward_fn(self, action): return -1

In [None]:
%time sum_random, sum_greedy = compare_agents(\
    RandomLengthAgent,\
    FastGreedyExpectedImmediateRewardAgent,\
    true_lm=XOLM,\
    agent_lm=XOLM,\
    n_sentences=100,\
    verbose=False\
)
print("Totals under new reward function, with old agent! Random:", sum_random, "-- Greedy:", sum_greedy)

%time sum_random, sum_greedy = compare_agents(\
    RandomLengthAgent,\
    PenaltyRewardFastGreedyExpectedImmediateRewardAgent,\
    true_lm=XOLM,\
    agent_lm=XOLM,\
    n_sentences=100,\
    verbose=False\
)
print("Totals under new reward function, with new agent! Random:", sum_random, "-- Greedy:", sum_greedy)

We change the reward function, and suddenly, greedy looks terrible! Why? Well, under our new reward function, all the actions look the same, and the way the greedy agent is coded, it defaults to the smallest action in the case of ties. So, the greedy agent always picks an action of size 1, and achieves the worst possible score. (Of course, we could mitigate this by changing the "default" behavior. But that's not really solving the issue.) How can we fix this?
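Assuming the greedy agent picks its action with Python's built-in `max` (or any argmax that returns the first maximum, such as `numpy.argmax`), a tie among all actions always resolves to the smallest one:

```python
# Hypothetical expected immediate rewards under the new reward function: all tied at -1.
expected_rewards = {action: -1.0 for action in range(1, 11)}

# Python's max returns the first maximal element it encounters, so the tie goes to action 1.
greedy_action = max(range(1, 11), key=lambda action: expected_rewards[action])
print(greedy_action)  # 1
```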

(Sidenote: one question you might have is, "why not just keep the old reward function"? Well, if the environment's reward is $r$, and the agent's policy maximizes some $\hat{r} \neq r$, we are taking a risk. We typically have no guarantees that maximizing $\hat{r}$ has any good effect on $r$ at all. For this toy example, it happens to be the case that the two reward functions are equivalent, but for more complex tasks, hand-designing good reward heuristics can be tedious and difficult. We want a *generic* solution that maximizes reward in any environment, regardless of when and how rewards are doled out.)

One solution to this issue is a kind of lookahead: let's *plan* out the rest of the trajectory! Instead of greedily choosing the action that maximizes immediate expected reward (i.e., minimizes immediate harm under our nasty new reward function), this allows us to choose the action that will give us the highest expected *return* (total reward over all further steps). Under our new reward function, that means trying to find short trajectories, which have a less negative return.

There is an obvious complication: to plan ahead, we would have to know what response the user will give to each of our actions. Since we don't know that, we have to again use the agent's language model $\hat{p}$ to hypothesize what the user might say -- and then we follow each path.

Take a look at this picture, showing one such *game tree* for predicting a length-3 string using the letters "a", "b", and "c":

<img width="100%" src="gametree.png" />

Each node of this tree is a *state* containing information about the string we've seen so far and the actions that took us there; the states are aligned by timestep in the to-be-predicted string. In blue we see the actions that can be taken (with the corresponding 1-best string under it), then for each action, the paths split depending on whether the user accepted the whole proposal (green arc) or rejected it (red arcs, for each possible rejection), both putting us in a new state from which we continue reasoning... until we've reached the end! As you can see, even for this little example, there are a *lot* of paths, namely as many as there are final states.

How many is that in this example? $\color{red}{\text{FILL IN}}$ 

How many is that in general for predicting $n$ characters from a set of $m$ symbols? (Hint: use recursion) $\color{red}{\text{FILL IN}}$

How many paths then would we have in our old prediction task with the old language model? $\color{red}{\text{FILL IN}}$ 

You should have arrived at an answer that more or less says "utterly infeasible". Nevertheless, in the spirit of learning by pain, we will still try to construct an agent that does precisely this: expand *all* paths in the game tree -- for our new, very small task, of course.

## The expected return is the expected total reward: marginalizing over all paths

How can we use this game tree to find the action that seems "best"? We are still looking to maximize reward, but this time we will not only see how much the immediate action (i.e., the immediate outgoing arc of the start state in our game tree) would give us and with which probability, but we will play each of these outcomes until the end -- only then will we know the worth of the initial action.

But how do we get from knowing a completed trajectory's return and having probabilities for environment responses (our LM) to tallying up the worth of that first initial action?

We should **marginalize** over environment responses, but **maximize** over agent choices.  That is because the agent's *policy* is to always choose the action it believes to be best. If we were using a stochastic policy, we would marginalize over the agent's choices, too.
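On a tree, "maximize over agent choices, marginalize over environment responses" is the classic *expectimax* recursion. Here is a minimal, generic sketch on a hypothetical toy environment (not the LM-based game tree above): the state is just the number of characters left, the agent proposes 1 or 2 characters, every step costs $-1$, and a 2-character proposal is accepted with probability 0.5 (otherwise the user types one character). All names are illustrative.

```python
def expected_return(state, actions, outcomes, is_terminal):
    """Expectimax: maximize over agent actions, marginalize over environment responses.

    actions(state)          -> iterable of available actions
    outcomes(state, action) -> list of (probability, reward, next_state)
    """
    if is_terminal(state):
        return 0.0
    return max(
        sum(p * (r + expected_return(s_next, actions, outcomes, is_terminal))
            for p, r, s_next in outcomes(state, a))
        for a in actions(state)
    )

# Hypothetical toy environment; the state n is the number of characters left.
def toy_actions(n):
    return range(1, min(2, n) + 1)

def toy_outcomes(n, a):
    if a == 1:
        return [(1.0, -1, n - 1)]
    return [(0.5, -1, n - 2), (0.5, -1, n - 1)]

def toy_is_terminal(n):
    return n == 0

for n in range(1, 5):
    print(n, expected_return(n, toy_actions, toy_outcomes, toy_is_terminal))
```

Note that the recursion re-expands identical subtrees over and over; without merging them, its cost grows exponentially with the number of characters left.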

In [None]:
class ExhaustivePlanningAgent(TextProposalAgent):
    def __init__(self, *, lm, lookahead_nchars):
        self.lm = lm
        self.lookahead_nchars = lookahead_nchars

    def success_reward_fn(self, action): return -1
    def fail_reward_fn(self, action): return -1
  
    def decision(self, *, state):
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


In [None]:
sample = "x␣xox␣oxoxxoxo␣xx␣xoxox␣xxooxoo␣x␣xoxooxo␣oxoxx␣ooxoxx␣ooxo␣oxxoo␣x"

for maxlen in range(1, 6):
    user = TypistState(lm=XOLM, string=sample[:12+maxlen], start_index=12, debug=True)
    %time user.evaluate_agent(agentclass=ExhaustivePlanningAgent, lm=XOLM, lookahead_nchars=maxlen)
    print()

This is still embarrassingly slow: runtime scales linearly with the number of paths, and that number scales exponentially in the length of the to-be-predicted suffix:

suffix length | paths | time (ms)
---|---|---
1 | 3 | 21
2 | 19 | 36
3 | 175 | 277
4 | 2123 | 3720
5 | 32043 | 66000
6 | 579095 | 1556000

How could we fix this? Of course, we could simplify our model --  and in a sense we already did that in the previous agents! The way we set up their internal belief state, our game tree is actually not a tree, but a DAG. Why? How many nodes does it have? $\color{red}{\text{FILL IN}}$

You should have come up with a number that is $O(|\mathbf{w}|)$ -- can we get a DAG that has $O(1)$ nodes? If yes, explain how, if no, explain why. $\color{red}{\text{FILL IN}}$

Could we do it if we didn't have the number of characters left as part of our state? $\color{red}{\text{FILL IN}}$

Alternatively, we could perform approximate inference, and that is what we will do: we will (finally) start using techniques from reinforcement learning to find our way through this enormous state space.

## Value functions judge states

If we could tell how good each state was, we could call off our search very early. In fact, we could take our greedy agent and judge each action not only by the immediate reward it brings, but also by how good the state we end up in is. How should we judge the "goodness" of the next state? One useful way to evaluate a state is to ask, "If I end up in this state, what return will I get over the rest of the episode?"

Let's make this more precise. Let's call our current policy $\pi$, i.e., $a = \pi(s)$. Define the function $V^\pi$ such that $V^\pi(s)$ is the expected return that we obtain by starting in state $s$ and choosing actions according to $\pi$ in every future state. We can use a procedure called *policy iteration* to find the best policy.

At a high level, here's how policy iteration works. First, start off with any $\pi$ and compute $V^\pi(s)$ for all states $s$. Then define a new policy $\pi'$ as follows: $\pi'(s)$ is the action that, among all actions we could take in $s$, maximizes the expected immediate reward plus the value $V^\pi$ of the resulting next state. This new policy $\pi'$ is guaranteed to score at least as well as the old $\pi$ in every state (and strictly better in some state, unless $\pi$ was already optimal)! Compute $V^{\pi'}(s)$ for this new policy, and repeat: define $\pi''=\cdots$.

This procedure should hopefully seem intuitive. Whenever we change our policy, we switch from a worse action (where "worse" means a lower expected return when we follow our current policy afterwards) to a better one. Since no change can make things worse, the overall policy can only improve; and since a finite MDP has only finitely many deterministic policies, the procedure eventually stops at an optimal policy.
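Here is a minimal tabular sketch of policy iteration on a tiny hypothetical MDP shaped like a miniature typing task: the state is the number of characters left, actions propose 1 or 2 characters, every step costs $-1$, and a 2-character proposal is accepted with probability 0.5. All names are illustrative; with a table this small we can afford to evaluate each policy by repeated sweeps, which is exactly what stops working for our real state space.

```python
STATES = range(0, 5)  # characters left; 0 is terminal

def actions(n):
    return range(1, min(2, n) + 1)

def outcomes(n, a):
    """List of (probability, reward, next_state) for taking action a in state n."""
    return [(1.0, -1, n - 1)] if a == 1 else [(0.5, -1, n - 2), (0.5, -1, n - 1)]

def q_value(V, n, a, gamma):
    return sum(p * (r + gamma * V[n_next]) for p, r, n_next in outcomes(n, a))

def evaluate_policy(policy, gamma=1.0, sweeps=50):
    """Iteratively compute V^pi; the terminal state keeps value 0."""
    V = {n: 0.0 for n in STATES}
    for _ in range(sweeps):
        for n in STATES:
            if n > 0:
                V[n] = q_value(V, n, policy[n], gamma)
    return V

def policy_iteration(gamma=1.0):
    policy = {n: 1 for n in STATES if n > 0}  # start with "always propose 1 character"
    while True:
        V = evaluate_policy(policy, gamma)
        improved = {n: max(actions(n), key=lambda a: q_value(V, n, a, gamma)) for n in policy}
        if improved == policy:  # no state changes its action: done
            return policy, V
        policy = improved

policy, V = policy_iteration()
print(policy)
print(V)
```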

There's just one gap: how do we actually compute $V^\pi$? We can learn it incrementally: we start with some (most likely wrong) estimate and then iteratively refine it. How can this refinement work? Let's refer to our current estimate as just $V$. For any given state $s$, we can improve $V(s)$ by recomputing it from the estimates $V(s')$ of the states $s'$ that the available actions lead to.
Consider this example:

<img width="100%" src="state_to_state_with_model.png" />

In state $s$, we can take actions $a_1$ and $a_2$. What happens if we do? This is where the agent's model of its environment comes in (the one that, in our example, already told us how likely it is that a given proposal is accepted): we know that taking action $a_1$ can land us in $s_{1,1}$, $s_{1,2}$, or $s_{1,3}$ (with certain probabilities, say $p_{1,1}$, $p_{1,2}$, and $p_{1,3}$).
Then we can recompute $V(s)$ using the immediate rewards $r_{*,*}$ and the value function estimates $V(s_{*,*})$!

As in the search case, the expected return from state $s$ on will be defined by the *best* action, since at any given step our current policy is to always choose the best action, where "best" can be conveniently defined using $V$:
$$
\begin{align}
    \text{expected-return}(s, a_1) &= \mathbb{E}[r_{1,i} + \gamma V(s_{1,i})] = \sum_{i=1}^3 p_{1,i} \cdot (r_{1,i} + \gamma V(s_{1,i})) \\
    \text{expected-return}(s, a_2) &= \mathbb{E}[r_{2,i} + \gamma V(s_{2,i})] = \sum_{i} p_{2,i} \cdot (r_{2,i} + \gamma V(s_{2,i}))
\end{align}
$$

The *discount factor* $\gamma$ (usually a constant like $.99$) is necessary in situations where trajectories can be infinitely long (think of cycles in some state graph), so that we don't end up with infinite returns. Since our trajectories are finite, we could technically just set $\gamma = 1$ without an issue, but in deep RL it is typical to include a discount even on finite-horizon problems for stability reasons. So, we set $\gamma = .99$.
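With a per-step reward of $-1$, the discounted return of a $T$-step trajectory is $-\sum_{t=0}^{T-1}\gamma^t = -(1-\gamma^T)/(1-\gamma)$, which still decreases as $T$ grows, so shorter trajectories remain better. A small sketch (`discounted_return` is an illustrative helper):

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over a trajectory's rewards r_0, r_1, ..."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))

ten_keystrokes = [-1] * 10
print(discounted_return(ten_keystrokes, 1.0))   # -10.0
print(discounted_return(ten_keystrokes, 0.99))  # slightly less negative than -10
```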

Note that we can hardcode $V(s)$ to be $0$ for all terminal states $s$ -- if we're done, we will not get any more rewards (and why try to learn that when we already know it).  We can identify which states are terminal by the fact that nchars_left is 0.

Now we can finally update $V(s)$, given that we know which action we would take in $s$. Let's say WLOG that $a_1$ had the highest $\text{expected-return}(s, a)$ out of all possible actions. We can then update $V(s)$ using the Bellman equation (an application of what is also called the "Bellman backup operator"):

$$
    V(s) \leftarrow \sum_{i=1}^3 p_{1,i} \cdot (r_{1,i} + \gamma V(s_{1,i})) = \mathbb{E}[r_{1,i} + \gamma V(s_{1,i})]
$$

If we iterate this process for all our states $s$, we will learn a good value function $V$ in all states, and acting greedily with respect to $V$ (always choosing the action with the highest $\text{expected-return}(s, a)$) will lead us to make *globally optimal* decisions, even though each decision is greedy! We call this the *optimal policy* and write it as $\pi^*$, with value function $V^*$.
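Repeating this backup with the max over actions inside it is known as *value iteration*. Below is a minimal tabular sketch on a tiny hypothetical MDP (state = characters left, actions propose 1 or 2 characters, reward $-1$ per step, 2-character proposals accepted with probability 0.5), using $\gamma = .99$ as above. The names are illustrative; the real agent cannot enumerate its states like this.

```python
def outcomes(n, a):
    """List of (probability, reward, next_state); state 0 is terminal."""
    return [(1.0, -1, n - 1)] if a == 1 else [(0.5, -1, n - 2), (0.5, -1, n - 1)]

def expected_return(V, n, a, gamma):
    return sum(p * (r + gamma * V[n_next]) for p, r, n_next in outcomes(n, a))

def value_iteration(n_max=4, gamma=0.99, tol=1e-10):
    V = [0.0] * (n_max + 1)  # arbitrary initial estimate; V[0] stays 0 (terminal)
    while True:
        delta = 0.0
        for n in range(1, n_max + 1):
            # Bellman backup: value of the best action under the current estimate
            new_v = max(expected_return(V, n, a, gamma) for a in range(1, min(2, n) + 1))
            delta = max(delta, abs(new_v - V[n]))
            V[n] = new_v
        if delta < tol:
            return V

V = value_iteration()
greedy = {n: max(range(1, min(2, n) + 1), key=lambda a: expected_return(V, n, a, 0.99))
          for n in range(1, 5)}
print(V)
print(greedy)
```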

Pretty cool! But wait... iterate for *all* states? Clearly that's utterly infeasible, when you have a state space as large as ours. And what should this function $V$ look like internally? A big table of all the millions of states? That can't be it... We will tackle these two problems in turn:

### Deep RL: approximating $V$ with a neural network

The simplest option for implementing a function $V$ is a big table in memory where for each $s$ a separate value is stored.
This setting is known as *tabular RL* ("tabular" is the adjectival form of "table").  While it gives many nice guarantees for convergence etc., it is not very practical for real-world AI settings.  The issue is that most problems have infeasibly large state spaces: not only can we not fit these giant tables in our computer memory, but we also lose out by not *sharing information* between related states.

The alternative to tabular reinforcement learning is sometimes called *function approximation RL*. Rather than a table-lookup, we have a function that maps arbitrary states to values.
That is where "Deep" comes into Reinforcement Learning. We choose a specific functional form for $V$: a neural network that outputs a scalar given some vector representation of the state $s$.

For our purposes, this means that the agent has to encode its state into a vector somehow. Well, we are in luck: we can just use the hidden state of the pretrained language model that we are already using (together with the number of characters left).
Given this vector, we could build an arbitrarily complicated neural network that outputs a scalar -- we will for now just use 2-layer feedforward networks.

Of course, this also means that the assignment in the equation above cannot be carried out literally: we cannot just "assign" an output value to a neural network. Instead, we will minimize the squared distance between the current estimate $V(s)$ and the value given by the right-hand side of the Bellman equation.
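In code, "moving $V(s)$ towards the right-hand side" means taking gradient steps on a squared error whose target is held fixed. Below is a minimal sketch with a toy linear $V$. The feature vector and the target value are made up and do not come from the notebook.

```python
import torch

torch.manual_seed(0)
V = torch.nn.Linear(4, 1)                  # toy value function: V(s) = w . phi(s) + b
phi_s = torch.full((4,), 0.5)              # hypothetical feature vector for state s
target = torch.tensor([-3.0])              # hypothetical right-hand side r + gamma * V(s')
optimizer = torch.optim.SGD(V.parameters(), lr=0.1)

losses = []
for _ in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(V(phi_s), target)  # (V(s) - target)^2
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.2e}, V(s) = {V(phi_s).item():.4f}")
```

After enough steps, $V(s)$ ends up (numerically) at the target. In real training the target keeps moving, because it is itself computed from $V$.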

In [None]:
class ValueFunctionApproximator(torch.nn.Module):
    def __init__(self, h_dims):
        super().__init__()
        # The main network
        self.linear1 = torch.nn.Linear(h_dims+1, 16)
        self.linear2 = torch.nn.Linear(16+1, 1)
        self.linear2.weight.data = torch.FloatTensor(1, 16+1).uniform_(-.01, .01)
        self.linear2.bias.data.fill_(0.00)
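        # The target network: same topology, but updated only slowly and indirectly
        # (see `update_target_network`), which makes it more stable.
        # `.clone()` gives it its own copy of the initial weights instead of sharing
        # storage with the main network (which would make it follow every optimizer step).
        self.linear1_target = torch.nn.Linear(h_dims+1, 16)
        self.linear1_target.weight.data = self.linear1.weight.data.clone()
        self.linear1_target.bias.data = self.linear1.bias.data.clone()
        self.linear2_target = torch.nn.Linear(16+1, 1)
        self.linear2_target.weight.data = self.linear2.weight.data.clone()
        self.linear2_target.bias.data = self.linear2.bias.data.clone()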

    def update_target_network(self):
        """
        Set the target network weights to a new moving average.
        """
        self.linear1_target.weight.data = 0.05 * self.linear1.weight.data + 0.95 * self.linear1_target.weight.data
        self.linear2_target.weight.data = 0.05 * self.linear2.weight.data + 0.95 * self.linear2_target.weight.data
        self.linear1_target.bias.data = 0.05 * self.linear1.bias.data + 0.95 * self.linear1_target.bias.data
        self.linear2_target.bias.data = 0.05 * self.linear2.bias.data + 0.95 * self.linear2_target.bias.data

    def forward(self, hcs, nchars_left, target=False):
        """
        Using PyTorch syntax, we define `forward()` to give us the scalar from
        the state representation we feed in: the LM hidden state and the number
        of characters left.
        The `target` parameter tells us whether to use the (moving average)
        target network weights.
        """
        if nchars_left <= 0: return torch.tensor(0.)
        lmrep = hcs[-1][0][0]
        inp = torch.cat([lmrep, torch.tensor([float(nchars_left)])])
        hid = (self.linear1_target if target else self.linear1)(inp).tanh()
        hid = torch.cat([hid, torch.tensor([float(nchars_left)])])
        return (self.linear2_target if target else self.linear2)(hid).squeeze()

In [None]:
# This is how we would use it:
hcs = LM.hcs_from_context("sequence␣modeling␣is␣")
V = ValueFunctionApproximator(LM.lstm_layers[-1].hidden_size)
V(hcs, 5)


We slipped in a few tricks here. First, note that we initialize the bias of the final layer to zero and its weights to very small values. This ensures that the initial predictions of the value function, before any training takes place, are close to zero. That improves stability, because no large, wrong values get bootstrapped into other states. Remember: even though none of the predictions are correct yet, each state's value estimate serves as the "right answer" for the state one timestep before it!

Second, we defined two versions of our neural network for predicting values. The main network has layers `linear1` and `linear2`, but we also have another network with the same topology, whose layers are called `linear1_target` and `linear2_target`. This *target network* is updated more slowly than the main network, so that it doesn't oscillate during training: its weights are a *moving average* of the weights of the main network. The trick is to use this target network in place of the main network when we compute the right-hand side of the Bellman equation, which can make training more stable and improve convergence. Note that no gradients flow into the target network. Its weights change only when we manually average in the current weights of the main network.
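The moving-average update in `update_target_network` is often called a *soft* or *Polyak* update. It can be written generically for any pair of modules with the same architecture. `soft_update` below is a hypothetical helper, not part of the notebook; `tau = 0.05` mirrors the constant used above. In this toy check, the main network is held fixed, so the target closes the gap to it by a factor of $1-\tau$ per call.

```python
import torch

def soft_update(target: torch.nn.Module, main: torch.nn.Module, tau: float = 0.05) -> None:
    """In place: target <- tau * main + (1 - tau) * target, parameter by parameter."""
    with torch.no_grad():  # the target network never receives gradients
        for p_target, p_main in zip(target.parameters(), main.parameters()):
            p_target.mul_(1.0 - tau).add_(p_main, alpha=tau)

main, target = torch.nn.Linear(3, 1), torch.nn.Linear(3, 1)
with torch.no_grad():
    for p in main.parameters():
        p.fill_(1.0)     # pretend the main network has settled on all-ones weights
    for p in target.parameters():
        p.zero_()        # and the target network still starts from zeros

for step in range(1, 21):
    soft_update(target, main)
    if step in (1, 10, 20):
        print(step, round(target.weight[0, 0].item(), 4))  # equals 1 - 0.95**step
```

Updating the parameters in place under `torch.no_grad()` keeps the target network entirely outside the autograd graph.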

In [None]:
def success_reward_fn(action): return -1
def failure_reward_fn(action): return -1


# We will need the expected return for the Bellman equation and the decisions:
def expected_returns(*, prefix_hcs, nchars_left, value_function, agent_lm, gamma=0.99, target=False):
    """
    `prefix_hcs` and `nchars_left` encode our state and will be used to find
    actions and possible environment feedback, according to the `agent_lm` LM.
    Returns a 1-D tensor with one entry per action proposal (the proposals being
    strings of length 1..nchars_left): the expected return of that proposal
    according to the `value_function` (its target network if `target=True`).
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


In [None]:
# Given such returns, the loss we want to minimize is easy to define as the
# squared distance between the current V estimate and the right-hand side of the
# Bellman equation, using `expected_returns`.
def value_function_loss_for_state(*, prefix_hcs, nchars_left, value_function, agent_lm):
    # The old value of the value function
    lhs = value_function(prefix_hcs, nchars_left)
    # Consult our model to get the next states and their values
    rhs = expected_returns(
        prefix_hcs=prefix_hcs,
        nchars_left=nchars_left,
        value_function=value_function,
        agent_lm=agent_lm,
        target=True  # this is where we use the target network!
    ).max()
    # Return the squared distance between that and the current estimate.
    # Note that we are NOT doing "residual gradient learning" here (that would
    # mean also using the gradients of the right-hand side to update), because
    # it often makes the optimization move towards the wrong thing. We want to
    # update the "past" using the "future". So, we detach:
    return torch.nn.functional.mse_loss(lhs, rhs.detach())
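To see what the `detach()` changes, here is a sketch with a toy linear value function $V(s) = w^\top \phi(s)$ on made-up features. None of these names come from the notebook. Let $\delta = V(s) - (r + \gamma V(s'))$. With the target detached, the gradient is $2\delta\,\phi(s)$, the so-called *semi-gradient*. Without the detach, it becomes $2\delta\,(\phi(s) - \gamma\phi(s'))$. This residual-gradient update also changes $V(s')$ to fit $V(s)$, i.e., it updates the "future" using the "past".

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, dtype=torch.float64, requires_grad=True)   # toy linear V weights
phi_s = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float64)    # hypothetical features of s
phi_next = torch.tensor([0.0, 1.0, 1.0], dtype=torch.float64) # hypothetical features of s'
r, gamma = -1.0, 0.99

def V(phi):
    return w @ phi

# Semi-gradient (what the notebook does): the Bellman target is treated as a constant.
loss_semi = (V(phi_s) - (r + gamma * V(phi_next)).detach()) ** 2
(grad_semi,) = torch.autograd.grad(loss_semi, w)

# Residual gradient: differentiate through the target as well.
loss_res = (V(phi_s) - (r + gamma * V(phi_next))) ** 2
(grad_res,) = torch.autograd.grad(loss_res, w)

delta = (V(phi_s) - (r + gamma * V(phi_next))).detach()   # TD error
print(grad_semi, 2 * delta * phi_s)                       # equal up to rounding
print(grad_res, 2 * delta * (phi_s - gamma * phi_next))   # equal up to rounding
```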

In [None]:
# That is how we would use the function:
value_function_loss_for_state(prefix_hcs=hcs, nchars_left=3, value_function=V, agent_lm=LM)

### Learning from rollouts

The next question is: which states should we actually use for our Bellman update? Ideally we could update all states, but that's impossible here.
Since we can only interact with the true environment one action and response at a time, we can't simply update any arbitrary state: we need to actually reach it via executing actions in the environment in order to see it and update it. And even if that were not true, enumerating all states in our giant state space would still be absolutely infeasible.

The answer here is that we will *roll out* some policy in the environment: take actions in response to each state, while recording all states we pass through. Rolling out a policy in the environment, we obtain a *sample trajectory* that performs the task for a full episode from start to end. Then, given this trajectory, we update $V(s)$ for all states that are on this trajectory. 
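Here is a minimal tabular sketch of this loop, on a made-up toy typing task:

- the state is the number of characters left;
- a $k$-character proposal is accepted with a hypothetical probability $0.5^k$;
- otherwise, the user types one character;
- every action costs $-1$.

As in this section, the agent is assumed to know the model. It rolls out its current greedy policy, records the visited states, and then applies the model-based Bellman backup only to those states. All names here are hypothetical.

```python
import random

ACCEPT, GAMMA = 0.5, 1.0   # hypothetical toy model; the agent knows it (model-based)

def expected_return(V, n, k):
    """Model-based one-step lookahead for proposing k of the n remaining characters."""
    p = ACCEPT ** k
    return -1.0 + GAMMA * (p * V.get(n - k, 0.0) + (1 - p) * V.get(n - 1, 0.0))

def greedy(V, n):
    return max(range(1, n + 1), key=lambda k: expected_return(V, n, k))

def rollout(V, n, rng):
    """Play one episode in the (sampled) environment and record the visited states."""
    visited = []
    while n > 0:
        visited.append(n)
        k = greedy(V, n)
        n = n - k if rng.random() < ACCEPT ** k else n - 1
    return visited

def learn_from_rollouts(n_start=6, episodes=200, seed=0):
    rng = random.Random(seed)
    V = {}   # table; unseen states (and the terminal state 0) default to 0
    for _ in range(episodes):
        for n in rollout(V, n_start, rng):            # update only visited states
            V[n] = max(expected_return(V, n, k) for k in range(1, n + 1))
    return V

V = learn_from_rollouts()
print(sorted(V.items()))
```

States that the greedy rollouts never reach never receive an update. This foreshadows the exploration problem discussed below.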

Note that the form of the Bellman update equation makes it look a little as if we were "trying out all the actions". However, as with the greedy agent, we are not testing all actions against the *true* environment: we only evaluate them using the *model* of the environment that we have. For now, we are assuming that our agent's model is perfect, so we can perform the update as soon as we have sampled a state. That assumption is unrealistic in many settings, and we will drop it in the next section of this homework.

In [None]:
class ValueFunctionExpectedReturnAgent(TextProposalAgent):
    def __init__(self, lm, value_function, exploration_policy):
        """
        `exploration_policy` is a function that given a tensor of expected returns, 
        yields the index of the action to take (usually, the best action).
        """
        self.lm = lm
        self.value_function = value_function
        self.exploration_policy = exploration_policy
        self.visited_cache = []

    def receive_response(self, *, state, action, reward, next_state):
        # save visited state information, to train on later
        _, hcs, nchars_left = state
        hcs = tuple((hc[0].detach(), hc[1].detach()) for hc in hcs)
        self.visited_cache.append(
            {
                "prefix_hcs": hcs,
                "nchars_left": nchars_left
            }
        )

    def decision(self, *, state):
        # Score every proposal, then let the exploration policy pick one
        # (argmax at test time, e.g. epsilon-greedy during training)
        _, hcs, nchars_left = state
        returns = expected_returns(
            prefix_hcs=hcs,
            nchars_left=nchars_left,
            value_function=self.value_function,
            agent_lm=self.lm
        )
        return self.exploration_policy(returns)

In [None]:
# Try it again -- this time, everything is super-fast since we essentially make greedy decisions
for maxlen in range(1, 10):
    user = TypistState(lm=XOLM, string=sample[:12+maxlen], start_index=12)
    vfa = ValueFunctionApproximator(XOLM.lstm_layers[-1].hidden_size)
    %time user.evaluate_agent(agentclass=ValueFunctionExpectedReturnAgent, lm=XOLM, value_function=vfa, exploration_policy=lambda t: torch.argmax(t))

We end up with a roughly linear dependence between sequence length and time, as expected -- as fast as we were with the greedy agent. Nearly there!

### We need to force exploration

What policy should we roll out in the environment? The obvious guess is our current policy: it is the best policy we know about, so surely the states it visits are the most valuable ones to learn about. However, if we only ever used our current policy to interact with the environment, there are many states that we would never visit. If one of those states were really good, we would never find out, because we would never go there!
So instead, we roll out an *exploration policy*: a policy that we *know* to be somewhat suboptimal, but that has a non-zero chance of taking every action, and hence of reaching every reachable state.

Specifically, we will use as our exploration policy an $\epsilon$-greedy version of our current policy, which is defined as follows: with probability $1-\epsilon$, take the best action, and with probability $\epsilon$, sample an action uniformly at random ($\epsilon$ is usually set to something like $.05$).
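Written out, the $\epsilon$-greedy rule gives each action the following probability in a state with $|\mathcal{A}(s)|$ available actions. Note that the uniformly random draw can also pick the best action. Here $\hat{G}(s,a)$ is just our shorthand (not notation used elsewhere in this homework) for the estimated expected return of action $a$:

$$
\pi_\epsilon(a \mid s) =
\begin{cases}
1 - \epsilon + \dfrac{\epsilon}{|\mathcal{A}(s)|} & \text{if } a = \mathrm{argmax}_{a'} \hat{G}(s, a'),\\[4pt]
\dfrac{\epsilon}{|\mathcal{A}(s)|} & \text{otherwise.}
\end{cases}
$$

For example, suppose $\epsilon = 0.05$ and there are 9 possible proposal lengths. The best length is then chosen with probability $0.95 + 0.05/9 \approx 0.956$, and each other length with probability $0.05/9 \approx 0.0056$.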

In [None]:
def eps_greedy_policy(returns, eps=.05):
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END


With this, we are ready to write the entire training loop:

In [None]:
def train(*, agentclass, params_to_optimize, agent_lm, exploration_policy, loss_function, target_network_updater, sentences, log_interval, cache_gradients, **kwargs):
    random.seed(0)
    sum_reward, sum_loss = 0, 0
    # Only update the desired parameters (i.e., the main network, not the target network)
    optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3)
    optimizer.zero_grad()
    for i, sentence in enumerate(sentences):
        # Run in environment
        env = TypistState(lm=XOLM, string=sentence, start_index=0)
        reward, agent = env.evaluate_agent(
            agentclass=agentclass,
            lm=agent_lm,
            return_agent=True,
            exploration_policy=exploration_policy,
            **kwargs
        )
        # Construct loss for function approximator
        loss = torch.sum(
            torch.stack(
                [loss_function(agent_lm=agent_lm, **d, **kwargs) for d in agent.visited_cache]
            )
        )
        # Output
        sum_reward += reward
        sum_loss += loss.item()
        if (i + 1) % log_interval == 0:
            print("Iter:", i+1)
            print("Avg reward (using exploration policy):", sum_reward / log_interval, "Avg loss:", sum_loss / log_interval)
            try:
                print("Sample argmax policy reward:", TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
                    agentclass=ValueFunctionExpectedReturnAgent, lm=XOLM, value_function=agent.value_function, exploration_policy=lambda t: torch.argmax(t)))
                print("VFA example:", ' '.join([f"{ncl}:{agent.value_function(XOLM.hcs_from_context(sentence[:-ncl]), ncl).item():.3f}" for ncl in reversed(range(1, 10))]))
            except AttributeError:
                print("Sample argmax policy reward:", TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
                    agentclass=QFunctionExpectedReturnAgent, lm=XOLM, q_function=agent.q_function, exploration_policy=lambda t: torch.argmax(t)))
                print("QFA example:", ' '.join([f"{ncl}:{agent.q_function(XOLM.hcs_from_context(sentence[:-ncl]), ncl).max().item():.3f}" for ncl in reversed(range(1, 10))]))
            print("Mean abs weights:", [p.abs().mean().item() for p in params_to_optimize])
            print()
            sum_reward, sum_loss = 0, 0
        # Optimize/update
        loss.backward()
        # Every `cache_gradients` episodes, apply the accumulated gradients to the
        # main network, then move the target network towards it
        if (i + 1) % cache_gradients == 0:
            optimizer.step()
            target_network_updater()
            optimizer.zero_grad()

In [None]:
def train_vf(agent_lm, sentences, cache_gradients=1):
    """
    You can increase `cache_gradients` for increased stability...
    ...at the cost of slower convergence!
    """
    vfa = ValueFunctionApproximator(agent_lm.lstm_layers[-1].hidden_size)
    train(
        agentclass=ValueFunctionExpectedReturnAgent,
        params_to_optimize=list(vfa.linear1.parameters()) + list(vfa.linear2.parameters()),
        agent_lm=agent_lm,
        exploration_policy=eps_greedy_policy,
        value_function=vfa,
        loss_function=value_function_loss_for_state,
        target_network_updater=vfa.update_target_network,
        sentences=sentences,
        log_interval=20,
        cache_gradients=cache_gradients
    )
    return vfa

First, let's try to overfit to a fixed set of sentences (here, a single sentence repeated). This is going to take a few minutes. Feel free to reduce the number of times we train on that sentence (it is set to 500 below), but be warned: if you cut training off too early, you could get very bad results. Good results require training for a long time.

In [None]:
testsentence = '␣xxo␣x␣ooxo␣xxxooxx␣oxoxx␣xoxo'
vfa_overfit = train_vf(XOLM, [testsentence] * 500)

In [None]:
print("Random gets:", sum([TypistState(lm=XOLM, string=testsentence, start_index=0).evaluate_agent(agentclass=RandomLengthAgent, lm=XOLM) for _ in range(10)])/10)
print("Greedy gets:", TypistState(lm=XOLM, string=testsentence, start_index=0).evaluate_agent(agentclass=FastGreedyExpectedImmediateRewardAgent, lm=XOLM))

print(
    "After training, we get this reward using our argmax policy:",
    TypistState(lm=XOLM, string=testsentence, start_index=0).evaluate_agent(
        agentclass=ValueFunctionExpectedReturnAgent,
        lm=XOLM,
        value_function=vfa_overfit,
        exploration_policy=lambda t: torch.argmax(t)
    )
)

Nice (although we brutally overfit)! Let's do a qualitative check whether the value function learned what we wanted it to learn:

In [None]:
intermediate_sentence = '␣xxo␣x␣ooxo␣xxxooxx'

# Print V for this context, pretending that 9, 8, ..., 1 characters are left:
print(' '.join([f"{vfa_overfit(XOLM.hcs_from_context(intermediate_sentence), ncl).item():.3f}" for ncl in reversed(range(1, 10))]))

# This is what our Bellman updates look like (with a well-trained model, both sides should be roughly equal):
prefix_hcs = XOLM.hcs_from_context(intermediate_sentence)
for ncl in reversed(range(1, 10)):
    returns = expected_returns(prefix_hcs=prefix_hcs, nchars_left=ncl, value_function=vfa_overfit, agent_lm=XOLM)
    print(vfa_overfit(prefix_hcs, ncl).item(), "is Bellman-updated to come closer to", returns[torch.argmax(returns)].item())

# This is how the input features are used:
print(vfa_overfit.linear1.weight[0])

Before peeking ahead, should the value function *increase* or *decrease* the further we progress through the sentence/trajectory? Why? $\color{red}{\text{FILL IN}}$

In [None]:
# Let's check for our example sentence, what the value function at every position is:
# (its magnitude should shrink towards 0 as we approach the end of the sentence)
for i in range(len(testsentence)):
    hcs = XOLM.hcs_from_context(testsentence[:i])
    ncl = len(testsentence)-i
    loss = value_function_loss_for_state(
        prefix_hcs=hcs,
        nchars_left=ncl,
        value_function=vfa_overfit,
        agent_lm=XOLM
    )
    print("After", i, "characters, we have V =", vfa_overfit(hcs, ncl).item(), "incurring a loss of", loss.item())

Well here's a fun oddity. Most of the states have quite decent value function judgements, it seems, but some of them incur a very large loss. Why didn't we learn something better?

The answer is simple: we almost never visited these states, so we (almost) never updated our value function there! This sort of thing often happens in RL, and we're not going to try to fix it. We can confirm this by looking at the empirical distribution over visited states (loosely, the *stationary distribution*): run the agent a bunch of times and count which states were actually visited:

In [None]:
from collections import Counter

stationary = Counter()

for _ in range(50):
    env = TypistState(lm=XOLM, string=testsentence, start_index=0)
    _, agent = env.evaluate_agent(
        agentclass=ValueFunctionExpectedReturnAgent,
        lm=XOLM,
        return_agent=True,
        value_function=vfa_overfit,
        exploration_policy=eps_greedy_policy
    )
    for d in agent.visited_cache:
        stationary[int(d["nchars_left"])] += 1

for i in range(len(testsentence)):
    print(f"The state after {i:2} characters:", '#' * stationary[len(testsentence) - i])

We can see that there are some states that were visited very often, and some that are visited very rarely. Why is that?

$\color{red}{\text{FILL IN}}$

All right, enough with this simple overfitting. It's time to train our agent on the real data distribution: every time it enters the environment, it has to guess a new sentence! Can we still do well?

In [None]:
def sample_generator(lm, temperature=0.2):
    while True:
        yield ''.join(itertools.islice(lm.greedy_sample(temperature=temperature), 30))

We will only train on 500 sentences, which should make us perform *okay*, but feel free to train for longer to see just how far this agent gets!

In [None]:
# Now train on some sentences from that generator
vfa_all = train_vf(XOLM, itertools.islice(sample_generator(XOLM), 500))

Looks like it trained well! Now let's compare this new agent with the random agent (`RandomLengthAgent`) on many sentences:

In [None]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

def random_vs_vf_generator():
    for sentence in sample_generator(XOLM):
        reward_random = TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
            agentclass=RandomLengthAgent,
            lm=XOLM   
        )
        reward_vf = TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
            agentclass=ValueFunctionExpectedReturnAgent,
            lm=XOLM,
            value_function=vfa_all,
            exploration_policy=lambda t: torch.argmax(t)
        )
        if abs(reward_vf - reward_random) > 4:
            print("Remarkable difference: on", sentence, "we have random:", reward_random, "value functions:", reward_vf)
        yield (int(reward_random), int(reward_vf))

def plot_agent_difference(reward_pair_generator, name1, name2):
    list1, list2, list_diff = [], [], []
    for i, (reward1, reward2) in enumerate(itertools.islice(reward_pair_generator, 100)):
        # Record the current pair first, so that each plot includes it
        list1.append(reward1)
        list2.append(reward2)
        list_diff.append(reward2 - reward1)
        if (i + 1) % 50 == 0:
            # Overlaid histograms of the (negative) total rewards of both agents
            c1 = Counter(list1)
            c2 = Counter(list2)
            plt.bar(range(-30, 0), [c1[r] for r in range(-30, 0)], alpha=0.3)
            plt.bar(range(-30, 0), [c2[r] for r in range(-30, 0)], alpha=0.3)
            plt.legend([name1, name2])
            plt.show()
            # Histogram of per-sentence reward differences in [-10, 10]
            cd = Counter(list_diff)
            plt.bar([d - 10 for d in range(21)], [cd[d - 10] for d in range(21)])
            plt.legend([name2 + " - " + name1])
            plt.show()
    print("\nsums:", sum(list1), "vs.", sum(list2))

plot_agent_difference(random_vs_vf_generator(), "random", "value function")

Looks like the value function helps us beat the random agent! The improvement is about the same as what the greedy search algorithm achieved back when the reward function was nice. The value function approach, however, gets it under this new reward function, in the total absence of short-term rewards. And it does so without needing exponentially slow planning.

The trade-off, of course, is the long training phase. The greedy planner and the multi-step planner could both be used immediately, whereas the value function took a long time to train before we could actually run the agent. One way to think about this trade-off is that we *distill* our planning into our value function. Each Bellman update can be thought of as a tiny, one-step plan. When learning a value function, we invest a lot of computation up front, making thousands of tiny plans and "storing" their outcomes (albeit in a very compressed form) by updating our value function. But once we are done training, we can query the value function and get the results of all that planning in constant time!

However, the value function approach still requires a good model of the true environment: the LM that gives the probabilities of acceptance and rejection for each proposal. In the final section, we will ask: what if we don't know anything about the environment we are in? But first, a few short notes on how you could make this value function learning better and faster.

### Possible improvements

1. **Update on minibatches.** As with gradient descent, relying on single samples means instability and slow convergence, so usually we would sample many trajectories before performing an update on our value function.
2. **Unroll the Bellman equation.** Instead of taking the values of the successor states as given by $V^\pi$ at face value, we could compute them too, as the expectation over these states' successors (i.e., brute-force lookahead for another step). This yields more precise results at the cost of more computation.
3. **Sample for the Bellman equation.** We have been exploiting the fact that we have a model $p$ to take an expectation over the next states that our actions could get to. This expectation may be slow to compute exactly, especially when we unroll the equation over multiple timesteps as above. So we can instead approximate the expectation under the model by sampling. (Note that this still uses the model of the environment rather than actual experience with the environment; it's just a speedup.) If we want to keep computation constant, this leads to an interesting tradeoff between unrolling depth (which helps exactness) and the number of samples drawn at each level of the unrolled lookahead.
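Improvement 3 can be sketched in a few lines: replace the exact sum over responses with an average over responses sampled from the model. The outcome probabilities, rewards and successor values below are made up, and `sampled_backup` is a hypothetical helper.

```python
import random

probs = [0.5, 0.3, 0.2]           # hypothetical model probabilities of each response
rewards = [-1.0, -1.0, -1.0]
next_values = [-2.0, -4.0, -5.0]  # hypothetical V(s') for each response
gamma = 0.99

# Exact expectation under the model: sum_i p_i (r_i + gamma V(s_i))
exact = sum(p * (r + gamma * v) for p, r, v in zip(probs, rewards, next_values))

def sampled_backup(n_samples, rng):
    """Monte Carlo estimate: average r + gamma V(s') over responses drawn from the model."""
    idx = rng.choices(range(len(probs)), weights=probs, k=n_samples)
    return sum(rewards[i] + gamma * next_values[i] for i in idx) / n_samples

rng = random.Random(0)
for n in (1, 10, 100, 10_000):
    print(f"{n:>6} samples: {sampled_backup(n, rng):.3f}   (exact: {exact:.3f})")
```

The estimate is unbiased, and its noise shrinks roughly like $1/\sqrt{n}$. This is why fewer samples per level buys a deeper lookahead at the same cost.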

## Model-free RL

Let's recap this assignment so far. First, we found that a simple greedy agent was able to take very good actions, so long as the reward function was amenable to planning over short-term horizons. When the reward function became more difficult, we needed to use either slow long-term planning or complicated value function approximation of the entire game tree in order to achieve the same level of performance. However, in all of these cases, we actually made the task too easy: we supplied our agent with both knowledge of the reward function and the language model that the environment actually used!

We basically handed our agent a complete (and very nearly correct!) model of the world -- a model in which, truly, search was all that was ever needed. So, really, instead of learning something about the environment (as one usually tries in RL), we just learned to search our existing perfect model of it. The fact that there are real trajectories with real rewards didn't matter to us at all.

This is a rather uncommon scenario in RL: one does not typically have access to models of this sort. If we had the data, we could try to *learn* a language model, but of course the learned model wouldn't be perfect. It should be easy to see that if the language model that the agent has and the language model that the environment uses to choose $\mathbf{w}$ differ significantly, agents trained using our current methods would have a tough time scoring well in the environment.

So, in this last section we will take away this perfect model and thus the capability of the agent to *reason* about what rewards it will get for any action. It will truly have nothing but the rewards it got from its rollouts to go on. Let's see how far this will get us. Can we still do better than the random agent?

### Learning $Q$ without modeling $p$

Say that we don't know $p$.  Where does our approach break if we no longer have this explicit environment model? We can identify two places:

1. During the Bellman update: we need to find out what rewards $r$ and states $s'$ an action *could get us into*, in order to compute the right-hand side of the Bellman update. If we can't make these inferences anymore, what signal can we rely on? Only the transitions $s \overset{a}{\rightarrow} s'$ that we actually observed in our roll-out phase!

Note that in order to compute the right-hand side of the Bellman update for value functions, it's not enough to have the outcome of one action: we need the outcomes of *all* actions. Since we can't "go back in time" and undo an action in the real environment, and our state space is large enough that we'll probably never get into the same state more than once, we will almost never see the outcomes of all actions. There is a way around this: if our interactions with the real world follow our current policy, then we can, in fact, guarantee that the distribution of next states that we reach after taking an action is the same as the one we want for the Bellman update. As a result, we can use each transition to improve the match of the left-hand side to the right-hand side, by getting a stochastic gradient of the squared difference between them.

(One classic version of this algorithm, called [TD-Lambda](https://en.wikipedia.org/wiki/Temporal_difference_learning#TD-Lambda), was successfully used back in 1992 to [learn to play backgammon at a very strong level, an early win for RL](https://en.wikipedia.org/wiki/TD-Gammon).)

But of course, we need to sample according to an exploration policy that tries all the actions, and that would "bake" our exploration policy into the estimates for $V(s')$. That is undesirable, since we want $V$ to describe our optimal `argmax` policy. The value function should tell us the expected return we would get if we always took the best actions, not if we took actions according to an exploration policy like the $\epsilon$-greedy policy.

The method we will introduce, Q-learning, solves that problem nicely: with Q-learning, you can use *any* policy for exploration. No matter what exploration policy $\pi$ you pick, you will converge to the same result, namely the Q-function of the *optimal* policy $\pi^*$. This holds at least in the tabular MDP case, provided that your exploration policy keeps trying every action in every state (over and over, as you keep learning) and the step size is decayed suitably.

2. Both in the Bellman update, and during the actual exploration, we still need to decide what action our policy $\pi$ tells us to take! In the previous section, we did this by using our model to make a one-step prediction, and comparing the outcomes of the various actions. We clearly can't do that now, without a model.

Here's a solution to both problems. We will learn a *Q-function* that given a state *and action* tells us the expected return, i.e., instead of querying $V(s)$ for a state $s$ we will now query $Q(s, a)$ for some state-action-pair $(s, a)$.
If, again, we had this function, a policy is trivial to derive:
$\pi^*(s) = \mathrm{argmax}_{a \in \mathcal{A}} Q^*(s, a)$.  (No need to query a given environment model $p$ anymore.) 

Most things stay the same as before, though. We again approximate $Q$ with a neural network, we will still use the hidden states of our pretrained LM as a convenient feature representation for this function approximator, and for the eventual proposal, we will still extract the 1-best strings for the proposed length from the pretrained language model (to keep things simple for this assignment).

So let's go through the only thing that really changes: the Bellman update for this Q-function. We already stated that we are in a situation where instead of being able to expand the game tree, we only observe one trajectory:

<img width="100%" src="state_to_state_without_model.png" />

The image hints at the solution: for a given state $s$ we can always define $V(s)$ using the Q-function: the value of the state is the maximum value you can get from taking any action $a$ in this state, where the value of taking $a$ in $s$ is precisely what $Q$ gives us.  

Once we reach a state $s$, we sample an action such as $a_1$ according to the exploration policy.  As the new Bellman update we would then want a new estimate for $Q(s, a_1)$:

$$
    Q(s, a_1) = \mathbb{E}_{r,s'}[r + \gamma V(s')] = \mathbb{E}_{r,s'}[r + \gamma \max_{a'} Q(s', a')]
$$

where the expectation is over the stochastic response $r,s'$ of the environment to action $a_1$ (crucially, note that this doesn't even mention the *exploration* policy).  The next action $a'$ ranges over the actions that are available in the new state $s'$.  For example, if $s' = s_{1,3}$ then we will maximize over $a' \in \{a_{1,3,1}, a_{1,3,2}\}$.  

1. The first insight needed for Q-learning is that repeating these Bellman updates on *all* the $(s,a)$ pairs over and over again -- no matter how we explore them -- would make the $Q$ function converge to the desired $Q^{\pi^*}$ function for the optimal policy.

2. The second insight is that we can't explicitly compute the expectation on the right-hand side, because the environment only gives us a single sample of $r,s'$. So again, we consider the squared error between the left-hand and right-hand sides, compute its gradient, and use our single sample to get a stochastic estimate of the gradient of that squared error with respect to $Q(s,a_1)$. We then take a step along that stochastic gradient. That's Q-learning.

3. The third insight is that we are no longer in the tabular case, so we can't take a stochastic gradient step on $Q(s,a_1)$ directly to reduce the squared error.  Rather, $Q(s,a_1)$ is the output of a neural net.  So we use backprop to take a stochastic gradient step on the parameters of the neural net.  That's deep Q-learning.
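Before the neural version, a minimal *tabular* sketch of the update may help. It uses a made-up toy typing task:

- the state is the number of characters left;
- a $k$-character proposal is accepted with a hypothetical probability $0.5^k$;
- otherwise, the user types one character;
- every action costs $-1$.

The learner only sees sampled `(r, s')` pairs returned by `step`, never the acceptance probability. It explores with a *uniformly random* behaviour policy, to illustrate that Q-learning still targets the optimal policy. For this toy, solving the Bellman optimality equation by hand gives $Q^*(1,1) = -1$, $Q^*(2,1) = -2$ and $Q^*(2,2) = -1.75$. All names are hypothetical.

```python
import random

ACCEPT = 0.5   # hypothetical; hidden inside the environment, unknown to the learner
N_START = 4

def step(n, k, rng):
    """Toy environment: propose k of n remaining characters; return (reward, next state)."""
    accepted = rng.random() < ACCEPT ** k
    return -1.0, (n - k if accepted else n - 1)

def q_learning(episodes=20_000, alpha=0.01, gamma=1.0, seed=0):
    rng = random.Random(seed)
    Q = {(n, k): 0.0 for n in range(1, N_START + 1) for k in range(1, n + 1)}
    for _ in range(episodes):
        n = N_START
        while n > 0:
            k = rng.randint(1, n)                 # behaviour policy: uniformly random
            r, n_next = step(n, k, rng)           # one sampled transition, no model
            # Bellman target r + gamma * max_a' Q(s', a'); the terminal state is worth 0
            best_next = max((Q[(n_next, a)] for a in range(1, n_next + 1)), default=0.0)
            Q[(n, k)] += alpha * (r + gamma * best_next - Q[(n, k)])  # stochastic step
            n = n_next
    return Q

Q = q_learning()
for n in range(1, N_START + 1):
    greedy_k = max(range(1, n + 1), key=lambda k: Q[(n, k)])
    print(n, {k: round(Q[(n, k)], 3) for k in range(1, n + 1)}, "greedy:", greedy_k)
```

Deep Q-learning replaces the table-entry update `Q[(n, k)] += ...` with a gradient step on the parameters of a network that outputs $Q(s, \cdot)$.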

All right! Let's implement it!

In [None]:
class QFunctionApproximator(torch.nn.Module):
    def __init__(self, h_dims):
        super().__init__()
        # The value network
        self.v_linear1 = torch.nn.Linear(h_dims+1, 16)
        self.v_linear2 = torch.nn.Linear(16+1, 1)
        self.v_linear2.weight.data = torch.FloatTensor(1, 16+1).uniform_(-.01, .01)
        self.v_linear2.bias.data.fill_(0.00)
        # The value target network: same topology, but updated only slowly and
        # indirectly (see `update_target_network`), which makes it more stable.
        # `.clone()` keeps its weights separate from the main network's storage.
        self.v_linear1_target = torch.nn.Linear(h_dims+1, 16)
        self.v_linear1_target.weight.data = self.v_linear1.weight.data.clone()
        self.v_linear1_target.bias.data = self.v_linear1.bias.data.clone()
        self.v_linear2_target = torch.nn.Linear(16+1, 1)
        self.v_linear2_target.weight.data = self.v_linear2.weight.data.clone()
        self.v_linear2_target.bias.data = self.v_linear2.bias.data.clone()
        # The action-value network
        self.a_linear1 = torch.nn.Linear(h_dims+1, 16)
        self.a_linear2 = torch.nn.Linear(16+1, 11)
        self.a_linear2.weight.data = torch.FloatTensor(11, 16+1).uniform_(-.01, .01)
        self.a_linear2.bias.data.fill_(0.00)
        # The more stable, since only slowly and indirectly updated target network
        self.a_linear1_target = torch.nn.Linear(h_dims+1, 16)
        self.a_linear1_target.weight.data = self.a_linear1.weight.data
        self.a_linear1_target.bias.data = self.a_linear1.bias.data
        self.a_linear2_target = torch.nn.Linear(16+1, 11)
        self.a_linear2_target.weight.data = self.a_linear2.weight.data
        self.a_linear2_target.bias.data = self.a_linear2.bias.data

    def update_target_network(self, tau=0.05):
        """
        Move the target network weights towards the online weights with an
        exponential moving average: target <- tau * online + (1 - tau) * target.
        """
        pairs = [
            (self.v_linear1, self.v_linear1_target),
            (self.v_linear2, self.v_linear2_target),
            (self.a_linear1, self.a_linear1_target),
            (self.a_linear2, self.a_linear2_target),
        ]
        for online, target in pairs:
            # Out-of-place arithmetic creates fresh tensors for the target layer
            target.weight.data = tau * online.weight.data + (1 - tau) * target.weight.data
            target.bias.data = tau * online.bias.data + (1 - tau) * target.bias.data

    def forward(self, hcs, nchars_left, target=False):
        """
        Using PyTorch syntax, we define `forward()` to map the state representation
        we feed in (the LM hidden state and the number of characters left) to a
        vector of 11 Q-values, one per proposal length 0..10; entry 0 is set to
        -inf, so that action is never chosen.
        The `target` parameter tells us whether to use the (moving-average)
        target network weights.
        """
        if nchars_left <= 0: return torch.tensor([0.]*11)
        lmrep = hcs[-1][0][0]
        ncl = torch.tensor([float(nchars_left)])
        inp = torch.cat([lmrep, ncl])
        # Value stream V(s)
        hid = (self.v_linear1_target if target else self.v_linear1)(inp).tanh()
        hid = torch.cat([hid, ncl])
        v = (self.v_linear2_target if target else self.v_linear2)(hid)
        # Advantage stream A(s, a)
        hid = (self.a_linear1_target if target else self.a_linear1)(inp).tanh()
        hid = torch.cat([hid, ncl])
        a = (self.a_linear2_target if target else self.a_linear2)(hid)
        # Dueling aggregation: *subtract* the mean advantage over the legal actions.
        # (Dividing by the mean instead can blow up when it is close to 0, as it can
        # be at initialization, and reverses the ranking of actions when it is negative.)
        a = a - torch.mean(a[1:])
        a[0] = torch.tensor(-float('inf'))
        return v + a

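`QFunctionApproximator` computes a state value $V(s)$ and one output per action and adds them, which is the idea behind dueling network architectures. It also keeps a target copy that is moved towards the online weights by an exponential moving average with rate 0.05 (often called a soft or Polyak update). The hypothetical helpers below sketch the standard form of both operations on plain Python lists: the dueling aggregation $Q(s,a) = V(s) + A(s,a) - \frac{1}{|\mathcal{A}_s|}\sum_{a' \in \mathcal{A}_s} A(s,a')$ over the legal actions $\mathcal{A}_s$, and the update $\theta^- \leftarrow \tau\,\theta + (1-\tau)\,\theta^-$.

```python
def dueling_q(v, advantages, legal):
    """Combine V(s) with advantages A(s, a): subtract the mean advantage over the
    legal actions and mask illegal actions with -inf."""
    legal_adv = [adv for adv, ok in zip(advantages, legal) if ok]
    mean_adv = sum(legal_adv) / len(legal_adv)
    return [v + adv - mean_adv if ok else float("-inf")
            for adv, ok in zip(advantages, legal)]


def soft_update(target, online, tau=0.05):
    """Polyak averaging, elementwise: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1 - tau) * t for o, t in zip(online, target)]


q = dueling_q(v=-3.0, advantages=[0.0, 0.5, -0.5, 1.0], legal=[False, True, True, True])
print(q)  # entry 0 is -inf; the legal entries average to V(s) = -3.0

target = [0.0, 0.0]
for _ in range(3):
    target = soft_update(target, online=[1.0, -1.0])
print(target)  # has moved a fraction 1 - 0.95**3 of the way towards [1.0, -1.0]
```

Subtracting the mean makes the split identifiable: without it, adding a constant to $V$ and subtracting it from every $A$ would leave $Q$ unchanged. The small $\tau$ means the regression targets change only slowly, which is what makes them more stable.
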


In [None]:
# This is how we would use it:
prefix_hcs, nchars_left = LM.hcs_from_context("sequence␣modeling␣is␣"), 8
Q = QFunctionApproximator(LM.lstm_layers[-1].hidden_size)
q_values = Q(hcs=prefix_hcs, nchars_left=nchars_left)
print(q_values)  # one Q-value per proposal length 0..10
# The Q-value of proposing the next 8 characters:
print(q_values[8])

In [None]:
def q_function_loss_for_state_action_reward_state(*, hcs, nchars_left, action, reward, next_hcs, next_nchars_left, q_function, agent_lm=None, gamma=0.99):
    """
    Both `hcs_state` and `nchars_left` have versions 1 and 2 for the two states
    we need here. Since we chose to encode the action by its final LM hidden
    state and its length, there is no need to feed in any more information about
    it, as the former is simple `hcs_state2` and the latter can be calculated as
    `nchars_left1 - nchars_left2`.
    We do, however, need to pass in the immediate reward that we received.
    Note that the `agent_lm` is only used to calculate all possible actions from
    state2, not to do model-based calculations!
    """
    ### STUDENTS START
    raise NotImplementedError()  # REPLACE ME
    ### STUDENTS END

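The quantity this function has to compute is the squared temporal-difference error from the second and third insights. Here is a minimal sketch of that arithmetic on plain floats, independent of the notebook's interfaces. The names `td_target` and `td_loss` are hypothetical. `next_q_values` is assumed to hold the next state's Q-values (for example from the target network), and a terminal next state is assumed to contribute no future value. In PyTorch, one common choice is to compute the target under `torch.no_grad()` (or `.detach()` it), so that backprop only changes the network through $Q(s,a)$.

```python
import math


def td_target(reward, next_q_values, gamma=0.99, terminal=False):
    """Single-sample Bellman target y = r + gamma * max_a' Q(s', a').
    A terminal next state contributes no future value."""
    if terminal:
        return reward
    return reward + gamma * max(next_q_values)


def td_loss(q_sa, reward, next_q_values, gamma=0.99, terminal=False):
    """Squared error between Q(s, a) and its target. The target is treated as a
    constant, so d loss / d Q(s, a) = 2 * (Q(s, a) - y)."""
    y = td_target(reward, next_q_values, gamma, terminal)
    return (q_sa - y) ** 2


# A masked entry (-inf) is ignored by max() as long as some next action is legal.
print(td_loss(q_sa=-5.0, reward=-2.0, next_q_values=[-math.inf, -3.0, -4.0], gamma=0.99))
```
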

In [None]:
# That is how we would use the function:
env = TypistState(lm=LM, string="sequence␣modeling␣is␣the␣best", start_index=20)
_, hcs, nchars_left = env.current_state
reward, (_, next_hcs, next_nchars_left) = env.execute_action(action=4)
q_function_loss_for_state_action_reward_state(
    hcs=hcs,
    nchars_left=nchars_left,
    action=4,
    reward=reward,
    next_hcs=next_hcs,
    next_nchars_left=next_nchars_left,
    q_function=Q
)

In [None]:
class QFunctionExpectedReturnAgent(TextProposalAgent):
    def __init__(self, q_function, exploration_policy, lm=None):
        self.q_function = q_function
        self.exploration_policy = exploration_policy
        self.visited_cache = []

    def receive_response(self, state, action, reward, next_state):
        """
        This time we can only append a completed dict to the `visisted_cache` list!
        """
        # Assemble tuple
        _, hcs, nchars_left = state
        _, next_hcs, next_nchars_left = next_state
        # Append
        self.visited_cache.append(
            {
                "hcs": hcs,
                "nchars_left": nchars_left,
                "action": action,
                "reward": reward,
                "next_hcs": next_hcs,
                "next_nchars_left": next_nchars_left
            }
        )

    def decision(self, *, state):
        """
        No need to append to our cache here (unlike the VF agent), but do
        remember to use the `self.exploration_policy`!
        """
        _, hcs, nchars_left = state
        hcs = tuple((hc[0].detach(), hc[1].detach()) for hc in hcs)
        ### STUDENTS START
        raise NotImplementedError()  # REPLACE ME
        ### STUDENTS END


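The `decision` method presumably hands the vector of Q-values to `self.exploration_policy`; in the evaluation cells further down, that policy is `lambda t: torch.argmax(t)`, i.e. a function from Q-values to an action index. The notebook's own `eps_greedy_policy` is defined elsewhere and may behave differently, so the following is only a hypothetical epsilon-greedy sketch over a plain list, which also skips masked (-inf) actions when exploring.

```python
import math
import random


def epsilon_greedy(q_values, epsilon=0.1, rng=random):
    """Hypothetical epsilon-greedy policy over a list of Q-values: with probability
    epsilon pick a uniformly random legal action (finite Q-value), else the argmax."""
    legal = [i for i, q in enumerate(q_values) if q != -math.inf]
    if rng.random() < epsilon:
        return rng.choice(legal)
    return max(legal, key=lambda i: q_values[i])


rng = random.Random(0)
q = [-math.inf, -4.0, -2.5, -3.0]
picks = [epsilon_greedy(q, epsilon=0.2, rng=rng) for _ in range(1000)]
print({a: picks.count(a) for a in sorted(set(picks))})  # action 2 dominates
```

With `epsilon=0` this reduces to the argmax policy used for evaluation; a small positive `epsilon` keeps visiting other proposal lengths so their Q-values keep receiving updates.
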
In [None]:
def train_qf(base_lm, sentences, cache_gradients=1):
    qf = QFunctionApproximator(base_lm.lstm_layers[-1].hidden_size)
    train(
        agentclass=QFunctionExpectedReturnAgent,
        # Only the online network is optimized; the target network follows it via `update_target_network`
        params_to_optimize=[
            *qf.v_linear1.parameters(), *qf.v_linear2.parameters(),
            *qf.a_linear1.parameters(), *qf.a_linear2.parameters(),
        ],
        agent_lm=None,
        exploration_policy=eps_greedy_policy,
        q_function=qf,
        loss_function=q_function_loss_for_state_action_reward_state,
        target_network_updater=qf.update_target_network,
        sentences=sentences,
        log_interval=50,
        cache_gradients=cache_gradients
    )
    return qf

Again, let's test whether we can overfit this one sentence. This time, we will probably need a lot more samples to get a reasonable estimate, because we learn only from the trajectories we experience, without a model. In real settings, "more samples" means millions or billions -- and our code in this notebook already takes a long time for thousands. But maybe the simplicity of our problem still makes this feasible?

In [None]:
testsentence = '␣xxo␣x␣ooxo␣xxxooxx␣oxoxx␣xoxo'
qf_overfit = train_qf(XOLM, [testsentence] * 1000)

In [None]:
# For our example sentence, let's check the maximum Q-value at every position
# (we should see it go up to 0):
for i in range(len(testsentence)):
    hcs = XOLM.hcs_from_context(testsentence[:i])
    ncl = len(testsentence) - i
    maxq = qf_overfit(hcs=hcs, nchars_left=ncl).max().item()
    print("After", i, "characters, we have a maximum Q of", maxq)

In [None]:
print("Random gets:", TypistState(lm=XOLM, string=testsentence, start_index=0).evaluate_agent(agentclass=RandomLengthAgent, lm=XOLM))

print(
    "After training, we get this reward using our argmax policy:",
    TypistState(lm=XOLM, string=testsentence, start_index=0).evaluate_agent(
        agentclass=QFunctionExpectedReturnAgent,
        lm=None,
        q_function=qf_overfit,
        exploration_policy=lambda t: torch.argmax(t)
    )
)

For the final test, we can again train on sentences sampled from the true distribution:

In [None]:
qf_all = train_qf(XOLM, itertools.islice(sample_generator(XOLM), 2000))

Time for the final test: how well does it perform against the random agent overall?

In [None]:
def random_vs_q_generator():
    for sentence in sample_generator(XOLM):
        reward_random = TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
            agentclass=RandomLengthAgent,
            lm=XOLM
        )
        reward_qf = TypistState(lm=XOLM, string=sentence, start_index=0).evaluate_agent(
            agentclass=QFunctionExpectedReturnAgent,
            lm=None,
            q_function=qf_all,
            exploration_policy=lambda t: torch.argmax(t)
        )
        if abs(reward_qf - reward_random) > 4:
            print("Remarkable difference: on", sentence,
                  "we have random:", reward_random, "vs. Q-functions:", reward_qf)
        yield (reward_random, reward_qf)

%matplotlib inline
plot_agent_difference(random_vs_q_generator(), "random", "q function")

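If you would rather have numbers than a plot, a hypothetical helper like the one below summarizes a fixed number of `(reward_random, reward_qf)` pairs. Since `random_vs_q_generator` loops over `sample_generator`, which is presumably infinite (it is wrapped in `itertools.islice` elsewhere), the helper takes only the first `n` pairs.

```python
import itertools
import statistics


def summarize_pairs(pairs, n=200):
    """Hypothetical helper: summarize the first n (reward_a, reward_b) pairs with
    each agent's mean reward, the mean per-sentence difference b - a, and the
    fraction of sentences on which agent b does strictly better."""
    pairs = list(itertools.islice(pairs, n))
    rewards_a = [a for a, _ in pairs]
    rewards_b = [b for _, b in pairs]
    diffs = [b - a for a, b in pairs]
    return {
        "mean_a": statistics.mean(rewards_a),
        "mean_b": statistics.mean(rewards_b),
        "mean_diff": statistics.mean(diffs),
        "win_rate_b": sum(d > 0 for d in diffs) / len(diffs),
    }


# Toy usage with made-up rewards (not results from this notebook):
print(summarize_pairs(iter([(-10, -8), (-12, -12), (-9, -6)])))
```

In the notebook, something like `summarize_pairs(random_vs_q_generator(), n=200)` would report the random agent as `a` and the Q-function agent as `b`. Each pair runs both agents on a fresh sentence, so a large `n` is slow.
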
Hopefully you see some decent numbers: the Q-learning agent should outperform the random baseline by about the same margin as the value-function agent did (which in turn is about the margin by which the greedy agent with nice rewards outperformed the random agent). But this time, we reached that score without any model at all!

Hopefully, you can see why Q-learning is much more general than the previous approaches. Consider a task like teaching a robot to wade through a river. To get a "perfect" model of the world, we would need to solve an incredibly complicated system involving the fluid dynamics of the river, the torque of the joints, the varying effects of pressure at different depths, the chance that the robot gets hit by a passing fish, and a million more things. The chances that we could model this scenario perfectly are slim. But with model-free Q-learning, we can sidestep all of those thorny issues and learn a policy directly!

Of course, Q-learning is not without its own issues. It is very sample-inefficient, and deep Q-learning in particular is unstable and difficult to get working. (Hence the 2 releases of this homework.) So there is still a lot of research to be done before we can learn practical robot gaits via deep Q-learning. But since the method is so general, it has a ton of potential!

# Congrats! You've won the game.

That's it for this assignment! If you want, play around with the reinforcement learning agent a bit more: try other policies and see how this has a bigger effect on $V$ than on $Q$, try more complicated networks, or try to speed things up by batching. If you ever use RL "in the wild", you will probably (as with most things in this class) write more task-specific and optimized code. But hopefully you have gotten a little glimpse of how reinforcement learning isn't all about moving robots, and how it can be instructive to think about in other settings, too!