This notebook goes with [this blog post](https://sgugger.github.io/pointer-cache-for-language-model.html#pointer-cache-for-language-model), which explains what the continuous cache pointer is. This technique was introduced by Grave et al. in [this article](https://arxiv.org/pdf/1612.04426.pdf).

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

This notebook uses the [fastai](https://github.com/fastai/fastai) library.

In [2]:
from fastai.text import *

Be sure to change the path to where the data is on your hard drive. The WikiText-2 dataset can be downloaded [here](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/).

In [3]:
EOS = '<eos>'
PATH=Path('../data/wikitext')

As indicated on their website, we just add the EOS token at the end of each line.

In [4]:
def read_file(filename):
    tokens = []
    with open(PATH/filename, encoding='utf8') as f:
        for line in f:
            tokens.append(line.split() + [EOS])
    return np.array(tokens)

In [5]:
trn_tok = read_file('wiki.train.tokens')
val_tok = read_file('wiki.valid.tokens')
tst_tok = read_file('wiki.test.tokens')

In [6]:
len(trn_tok)

36718

We numericalize the tokens into ids.

In [7]:
cnt = Counter(word for sent in trn_tok for word in sent)
itos = [o for o,c in cnt.most_common()]
itos.insert(0,'_pad_')

In [8]:
vocab_size = len(itos); vocab_size

33279

And here is the way from tokens to ids.

In [9]:
stoi = collections.defaultdict(lambda : 5, {w:i for i,w in enumerate(itos)})

In [10]:
trn_ids = np.array([([stoi[w] for w in s]) for s in trn_tok])
val_ids = np.array([([stoi[w] for w in s]) for s in val_tok])
tst_ids = np.array([([stoi[w] for w in s]) for s in tst_tok])
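
To make the token-to-id mapping concrete, here is a minimal sketch of the same steps on a toy corpus. The names `toy_tok`, `toy_itos`, `toy_stoi` and `toy_ids` are made up for this illustration, and the fallback id 0 for unknown words is an arbitrary choice for the toy example, not the value used above.

```python
from collections import Counter, defaultdict

# Toy corpus standing in for the tokenized lines (illustrative only).
toy_tok = [["the", "cat", "sat", "<eos>"], ["the", "dog", "<eos>"]]

# Most frequent tokens get the smallest ids; index 0 is reserved for padding.
toy_cnt = Counter(word for sent in toy_tok for word in sent)
toy_itos = [w for w, _ in toy_cnt.most_common()]
toy_itos.insert(0, "_pad_")

# Words missing from the vocabulary fall back to a default id
# (0 here, an arbitrary choice for this toy example).
toy_stoi = defaultdict(lambda: 0, {w: i for i, w in enumerate(toy_itos)})

# Numericalize: every sentence becomes a list of ids.
toy_ids = [[toy_stoi[w] for w in sent] for sent in toy_tok]

# Mapping the ids back through toy_itos recovers the original tokens.
roundtrip = [[toy_itos[i] for i in sent] for sent in toy_ids]
```

Going back through `toy_itos` should give the original tokens, which is a cheap check that the two mappings agree.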

Those are the parameters of our model.

In [11]:
em_sz,nh,nl = 400,1150,3
drops = np.array([0.6,0.4,0.5,0.05,0.2])

This is just to create a learner object that won't be used since we don't train here.

In [12]:
bptt, bs = 5,2

In [13]:
trn_dl = LanguageModelLoader(np.concatenate(trn_ids), bs, bptt)
val_dl = LanguageModelLoader(np.concatenate(val_ids), bs, bptt)
md = LanguageModelData(PATH, 0, vocab_size, trn_dl, val_dl, bs=bs, bptt=bptt)

In [14]:
opt_fn = partial(optim.SGD, momentum=0.9)
learner= md.get_model(opt_fn, em_sz, nh, nl,
    dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])

The model I use as an example is stored [here](https://s3.us-east-2.amazonaws.com/sgugger/best.h5). Be sure to have the file best.h5 in a directory called models inside the directory PATH points to (or replace it with any model you've saved).

In [15]:
learner.load('best')

Let's begin by computing how well our model is doing before anything else. To do that we will need a way to go through all of our text, but instead of using the fastai LanguageModelLoader (which randomly modifies the bptt) we'll change the code to have a fixed bptt.

Also, we don't want to use mini-batches for this validation because it resets the hidden state at each batch, making us lose valuable information. It makes a tiny bit of difference, as we will see.

In [16]:
#Comes from the LanguageModelLoader class, I just removed the minibatch and fixed the bptt.
#Now it gives an iterator that will spit bits of size bptt.
class TextReader():
    def __init__(self, nums, bptt, backwards=False):
        self.bptt,self.backwards = bptt,backwards
        self.data = self.batchify(nums)
        self.i,self.iter = 0,0
        self.n = len(self.data)

    def __iter__(self):
        self.i,self.iter = 0,0
        while self.i < self.n-1 and self.iter<len(self):
            res = self.get_batch(self.i, self.bptt)
            self.i += self.bptt
            self.iter += 1
            yield res

    def __len__(self): return self.n // self.bptt 

    def batchify(self, data):
        data = np.array(data)[:,None]
        if self.backwards: data=data[::-1]
        return T(data)

    def get_batch(self, i, seq_len):
        source = self.data
        seq_len = min(seq_len, len(source) - 1 - i)
        return source[i:i+seq_len], source[i+1:i+1+seq_len].view(-1)

This TextReader will give us an iterator that will allow us to go through the text. 
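
To see what the iterator yields without needing fastai or a GPU, here is a rough NumPy sketch of the same chunking logic. The function name `fixed_chunks` is mine, and it only mimics the indexing of `TextReader` (no tensors, no extra batch dimension).

```python
import numpy as np

def fixed_chunks(ids, bptt):
    """Yield (inputs, targets) pairs of at most bptt tokens.

    Targets are the inputs shifted by one token. Like TextReader, at most
    len(ids) // bptt chunks are produced, so leftover tokens are dropped.
    """
    ids = np.asarray(ids)
    n_chunks = len(ids) // bptt
    for k in range(n_chunks):
        i = k * bptt
        # Never read past the last token that still has a target after it.
        seq_len = min(bptt, len(ids) - 1 - i)
        if seq_len <= 0:
            break
        yield ids[i:i + seq_len], ids[i + 1:i + 1 + seq_len]

# Ten fake token ids cut into chunks of four.
chunks = list(fixed_chunks(np.arange(10), bptt=4))
```

Each chunk starts where the previous one stopped, so a model that keeps its hidden state between calls sees the text as one continuous stream.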

In [25]:
def my_validate(model, source, bptt=2000):
    data_source = TextReader(source, bptt)
    model.eval()
    model.reset()
    total_loss = 0.
    for inputs, targets in tqdm(data_source):
        #The language model throws up a bunch of things, we'll focus on that later. For now we just want the outputs.
        outputs, raws, outs = model(V(inputs))
        #The output doesn't go through softmax so we can use the CrossEntropy loss directly 
        total_loss += F.cross_entropy(outputs, V(targets), size_average=False).data[0]
    #Total size is length of our iterator times bptt
    mean = total_loss / (bptt * len(data_source))
    #Returns loss and perplexity.
    return mean, np.exp(mean)

In [26]:
my_validate(learner.model, np.concatenate(val_ids))


100%|██████████| 108/108 [00:37<00:00,  2.85it/s]


(4.304896561234085, 74.06155422085088)

This model was giving me a final validation loss of 4.317807 when it was computed with mini-batches, so we can see we gained a tiny bit by not resetting the hidden state.
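
As a quick sanity check on those numbers, perplexity is just the exponential of the mean per-token loss. The helper `perplexity` below is a name I made up for this sketch, and the two losses are simply the values quoted above.

```python
import math

def perplexity(mean_loss):
    """Perplexity is exp of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_loss)

# Validation loss without resetting the hidden state (from my_validate).
val_ppl = perplexity(4.304896561234085)

# Validation loss previously obtained with mini-batches.
batched_ppl = perplexity(4.317807)

# How much perplexity we gained by keeping the hidden state.
gain = batched_ppl - val_ppl
```

The gap in loss looks tiny, but since perplexity is exponential in the loss it translates to roughly one point of perplexity.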

In [27]:
my_validate(learner.model, np.concatenate(tst_ids))

100%|██████████| 122/122 [00:42<00:00,  2.84it/s]


(4.25155159972144, 70.21427231666625)

We will need to one-hot encode our targets, so we'll use a little helper function.

In [28]:
def one_hot(vec, size=vocab_size, cuda=True):
    a = torch.zeros(len(vec), size)
    for i,v in enumerate(vec):
        a[i,v] = 1.
    return V(a)
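
The Python loop above is fine for our purposes, but the same encoding can be written without a loop. Here is a sketch in plain NumPy (the name `one_hot_np` is hypothetical, and it returns an array rather than a fastai Variable), mostly to show what the matrix looks like.

```python
import numpy as np

def one_hot_np(ids, size):
    """Return a (len(ids), size) float array with a single 1 per row."""
    ids = np.asarray(ids)
    out = np.zeros((len(ids), size), dtype=np.float32)
    # Fancy indexing sets out[row, ids[row]] = 1 for every row at once.
    out[np.arange(len(ids)), ids] = 1.0
    return out

# Three target ids in a vocabulary of size four.
encoded = one_hot_np([2, 0, 3], size=4)
```

Row i of the result has its single 1 in column `ids[i]`, so multiplying by a weight and summing over rows accumulates that weight on the right word.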

Before we write the cache pointer, let's have a look at what our language model spits out when we send it an input. Looking at the source code of get_language_model, we see our model is of type SequentialRNN and combines an RNNEncoder and a LinearDecoder.

SequentialRNN is just to wrap a sequence of models while keeping a reset attribute (to reset the hidden states of the RNN basically). The last model being the LinearDecoder, it's the one that will give the output, so let's have a look.

```
def forward(self, input):
    raw_outputs, outputs = input
    output = self.dropout(outputs[-1])
    decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
    result = decoded.view(-1, decoded.size(1))
    return result, raw_outputs, outputs
```
It returns three things: the result, which is the decoded version of the last output of our RNNs, and it also returns raw_outputs and outputs, which seem to come from the previous block. So let's have a look at the RNNEncoder forward function.

```
def forward(self, input):
    sl,bs = input.size()
    if bs!=self.bs:
        self.bs=bs
        self.reset()

    emb = self.encoder_with_dropout(input, dropout=self.dropoute if self.training else 0)
    emb = self.dropouti(emb)

    raw_output = emb
    new_hidden,raw_outputs,outputs = [],[],[]
    for l, (rnn,drop) in enumerate(zip(self.rnns, self.dropouths)):
        current_input = raw_output
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            raw_output, new_h = rnn(raw_output, self.hidden[l])
        new_hidden.append(new_h)
        raw_outputs.append(raw_output)
        if l != self.nlayers - 1: raw_output = drop(raw_output)
        outputs.append(raw_output)

    self.hidden = repackage_var(new_hidden)
    return raw_outputs, outputs
```
And now we see that raw_outputs contains the outputs (aka the hidden states) of our RNNs, and outputs is the same after dropout has been applied. We will need the real hidden states for our neural cache, so we will use raw_outputs.
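
As a reminder of why the hidden states matter, here is the cache distribution of Grave et al. written out. The notation is an assumption on my part to match the code: $h_t$ is the last-layer hidden state at position $t$, $x_{i+1}$ is the word that followed position $i$, the sums run over the positions kept in the cache window, and $\theta$ and $\lambda$ are the `theta` and `lambd` hyperparameters.

```latex
p_{\text{cache}}(w \mid h_{1..t}, x_{1..t})
  = \frac{\sum_{i} \mathbb{1}\{w = x_{i+1}\}\, \exp(\theta\, h_t^{\top} h_i)}
         {\sum_{i} \exp(\theta\, h_t^{\top} h_i)},
\qquad
p(w) = (1 - \lambda)\, p_{\text{vocab}}(w) + \lambda\, p_{\text{cache}}(w)
```

In words: every cached position votes for the word that followed it, with a weight given by a softmax over the scaled dot products between its hidden state and the current one.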

This function will evaluate the model with the cache pointer on top of it. If you want to make a single prediction, you will have to adapt the code a bit. Hyperparameter values are taken from Stephen Merity et al.

In [29]:
def my_cache_pointer(model, source, theta = 0.662, lambd = 0.1279, window=3785, bptt=2000):
    data_source = TextReader(source, bptt)
    #Set the model into eval mode.
    model.eval()
    #Just to create a hidden state.
    model.reset()
    total_loss = 0.
    #Containers for the previous targets/hidden states.
    targ_history = None
    hid_history = None
    for inputs, targets in tqdm(data_source):
        outputs, raws, outs = model(V(inputs))
        #The outputs aren't softmaxed, so we have to do it to get the p_vocab vectors.
        p_vocab = F.softmax(outputs,1)
        #We take the last hidden states (raws contains one Tensor for the results of each layer) and remove the batch dimension.
        hiddens = raws[-1].squeeze() 
        #Start index inside our history.
        start = 0 if targ_history is None else targ_history.size(0)
        #Add the targets and hidden states to our history.
        targ_history = one_hot(targets) if targ_history is None else torch.cat([targ_history, one_hot(targets)])
        hid_history = hiddens if hid_history is None else torch.cat([hid_history, hiddens])
        for i, pv in enumerate(p_vocab):
            #Get the cached values
            p = pv
            if start + i > 0:
                targ_cache = targ_history[:start+i] if start + i <= window else targ_history[start+i-window:start+i]
                hid_cache = hid_history[:start+i] if start + i <= window else hid_history[start+i-window:start+i]
                #This is explained in the blog post.
                all_dot_prods = torch.mv(theta * hid_cache, hiddens[i])
                softmaxed = F.softmax(all_dot_prods, 0).unsqueeze(1)
                p_cache = (softmaxed.expand_as(targ_cache) * targ_cache).sum(0).squeeze()
                p = (1-lambd) * pv + lambd * p_cache
            total_loss -= torch.log(p[targets[i]]).data[0]
        targ_history = targ_history[-window:]
        hid_history = hid_history[-window:]
    #Total size is length of our iterator times bptt
    mean = total_loss / (bptt * len(data_source))
    #Returns loss and perplexity
    return mean, np.exp(mean)

This differs a bit from the implementation of Stephen Merity et al. [here](https://github.com/salesforce/awd-lstm-lm), since they only start using the cache once they have at least window values in it, but I found slightly better results using it from the very beginning.
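
If you want to play with the cache on a single position without loading a model, the inner step can be sketched in plain NumPy. This is only an illustration under my own assumptions: `cache_step` and `softmax` are made-up helper names, the inputs are random toy arrays, and the cached targets are passed as ids instead of one-hot rows.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

def cache_step(p_vocab, hidden, hid_cache, targ_cache, theta=0.662, lambd=0.1279):
    """Mix the model distribution with the cache distribution for one position.

    p_vocab:    (vocab,) model probabilities for the next word
    hidden:     (d,) current hidden state
    hid_cache:  (n, d) cached hidden states
    targ_cache: (n,) ids of the words that followed each cached state
    """
    # One weight per cached position: softmax of the scaled dot products.
    weights = softmax(theta * hid_cache @ hidden)
    # Accumulate each weight on the word that followed that position.
    p_cache = np.zeros_like(p_vocab)
    np.add.at(p_cache, targ_cache, weights)
    return (1 - lambd) * p_vocab + lambd * p_cache

# Toy setup: vocabulary of 5 words, 3 cached positions, hidden size 4.
rng = np.random.default_rng(0)
p_vocab = np.full(5, 0.2)
hid_cache = rng.normal(size=(3, 4))
targ_cache = np.array([1, 3, 1])
p = cache_step(p_vocab, hid_cache[0], hid_cache, targ_cache)
```

Words that appear in the cache (ids 1 and 3 here) get a boost, while every other word is scaled down by `1 - lambd`, and the result still sums to one.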

In [30]:
my_cache_pointer(learner.model, np.concatenate(val_ids))

100%|██████████| 108/108 [21:46<00:00, 12.10s/it]


(3.9970045292146206, 54.434847575761616)

In [31]:
my_cache_pointer(learner.model, np.concatenate(tst_ids))

100%|██████████| 122/122 [24:17<00:00, 11.94s/it]


(3.95311762462693, 52.09753447104239)

So we went from 74.06/70.21 perplexity to 54.43/52.10, not so bad!