# Large Scale Language Model with Softmax Approximation

We'll show you how to define a language model that approximates the softmax during training and calculates the exact softmax at test time.

In the [previous notebook](language_model.ipynb), we demonstrated how to get up and running quickly with GluonNLP, implementing your own language models (or grabbing popular off-the-shelf architectures) on a variety of datasets.

However, we left out one sticky issue that complicates applying neural network language models to large-scale problems: the size of the softmax output layer grows with the size of the vocabulary.
For vocabularies with hundreds of thousands of words, the output matrix alone can contain hundreds of millions of parameters.
And while we've seen how to get around this computational difficulty on the input side [by using word embeddings](../06_word_embedding/word_embedding.ipynb), on the output side we still need to compute the logit of every word at every time step in order to compute the partition function.
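A quick count shows how fast a dense output layer grows. This sketch assumes a vocabulary of 800,000 words and a hypothetical hidden size of 1024; a dense output layer stores one weight row and one bias per word.

```python
# Assumed sizes: an ~800K-word vocabulary and a hypothetical hidden size of 1024.
vocab_size = 800_000
hidden_size = 1024

# A dense softmax layer has one weight row and one bias entry per word.
weight_params = vocab_size * hidden_size
bias_params = vocab_size
total_params = weight_params + bias_params

print(f"output-layer parameters: {total_params:,}")
```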


In this notebook, we'll show you a technique for approximating the softmax distribution using importance sampling.

## The Problem of Language Models with Large Vocabularies

As we just described, when a word-level language model is trained on a corpus with large vocabulary size, 
the output layer easily becomes the bottleneck. 

<img src="softmax.png" width="500">

![title](archi.png)


Approach:
- Sampled softmax for training
- Full softmax for testing

Specifically, at each time step our language model must compute the (unnormalized) score of *every word **w** in the vocabulary*, because the partition function (the denominator of the softmax) requires the logit _**z(w)**_ of each word.
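Written out for the hidden state $h_t$ at time step $t$, with $z(w)$ the logit of word $w$ and $V$ the vocabulary, the softmax is:

$$
p(w \mid h_t) = \frac{\exp(z(w))}{Z(h_t)}, \qquad Z(h_t) = \sum_{w' \in V} \exp(z(w'))
$$

Only the target word's probability enters the loss, but the partition function $Z(h_t)$ sums over all of $V$, so every logit has to be computed.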

Calculating the exact softmax can consume significant memory and computation.
For example, the Google Billion Words (GBW) <sup>[1]</sup> dataset has a vocabulary of ~800K words.
For a mini-batch of 256 sentences unrolled 20 time steps each, the output of the softmax layer alone
takes 256 \* 20 \* 800K \* 4 bytes ≈ 16.4 GB of memory.
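The memory estimate is easy to reproduce. This sketch uses the batch size, unroll length, vocabulary size, and float32 width from the text.

```python
# Numbers from the text: 256 sentences, 20 time steps, ~800K words, float32.
batch_size = 256
time_steps = 20
vocab_size = 800_000
bytes_per_float = 4

# One logit must be materialized per (sentence, time step, word).
num_outputs = batch_size * time_steps * vocab_size
num_bytes = num_outputs * bytes_per_float

print(f"{num_outputs:,} floats = {num_bytes / 1e9:.1f} GB")
```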

One way to combat this issue is to approximate the gradient of the output layer with a sampling-based scheme.
We focus here on the importance-sampling scheme introduced by Bengio et al. <sup>[2]</sup>.
The importance-sampling approach trains a multi-class classifier to distinguish the true word
from words drawn from a proposal distribution.
Instead of computing logits for all classes,
training only needs the logits of the true word and of K sampled classes.
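To see the savings, here is the same mini-batch estimate for the sampled case. It assumes K = 8192 sampled words shared by every position in the mini-batch (the value this notebook uses later, though on a much smaller vocabulary), so each position needs 1 + K logits instead of one per vocabulary word.

```python
batch_size = 256
time_steps = 20
vocab_size = 800_000
num_sampled = 8192      # K, assumed shared across the whole mini-batch
bytes_per_float = 4

full_bytes = batch_size * time_steps * vocab_size * bytes_per_float
# Sampled training only needs the true word's logit plus K sampled logits.
sampled_bytes = batch_size * time_steps * (1 + num_sampled) * bytes_per_float

print(f"full softmax:    {full_bytes / 1e9:.2f} GB")
print(f"sampled softmax: {sampled_bytes / 1e9:.2f} GB")
print(f"reduction:       {full_bytes / sampled_bytes:.1f}x")
```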

### Preparation

#### Load gluonnlp

```python
import warnings
warnings.filterwarnings('ignore')

import time
import math

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn, rnn

import gluonnlp as nlp
from gluonnlp import data, model
# local helper modules that accompany this notebook
from sampler import LogUniformSampler
from utilities import detach
```

#### Define model architecture

In the next code block, we define the embedding (encoder), recurrent, and decoder layers of the model. Note that `ISDense` computes importance-sampling-based logits and is used for training, while a regular `Dense` layer computes the full logits for testing. Both decoders share the same weight and bias parameters.

Besides `ISDense`, GluonNLP also provides `NCEDense` for computing the noise-contrastive estimation (NCE) loss.

```python
class RNNModel(gluon.Block):
    """A model with an encoder, recurrent layer, and a decoder."""
    def __init__(self, vocab_size, num_embed, num_hidden,
                 num_layers, num_sampled, dropout=0., **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.num_hidden = num_hidden
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(vocab_size, num_embed, weight_initializer=mx.init.Uniform(0.1))
        self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout, input_size=num_embed)

        self.sampled_decoder = model.ISDense(vocab_size, num_sampled, num_hidden)
        self.decoder = nn.Dense(vocab_size, in_units=num_hidden,
                                params=self.sampled_decoder.params)

    def forward(self, inputs, hidden, sample_mode, *args):
        emb = self.drop(self.encoder(inputs))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output).reshape((-1, self.num_hidden))
        if sample_mode:
            sampled_values, targets = args
            targets = targets.reshape((-1, 1))
            decoded, new_targets = self.sampled_decoder(output, sampled_values, targets)
            return decoded, hidden, new_targets
        else:
            decoded = self.decoder(output)
            return decoded, hidden

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)
```

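The core of the importance-sampled logits computed by `ISDense` looks roughly like the following MXNet-style pseudocode, where `F` is `mx.nd` or `mx.sym`, `x` holds the hidden states, `w_true`/`b_true` are the output-layer rows of the target words, `w_sampled`/`b_sampled` are the rows of the sampled words, and `expected_count_*` are the expected counts returned by the sampler:
```
# logits of the K sampled words, shape (N, K)
pred_sampled = F.FullyConnected(x, weight=w_sampled, bias=b_sampled,
                                num_hidden=num_sampled)
# logit of each true word, shape (N,)
pred_true = (w_true * x).sum(axis=1) + b_true
# subtract log(q) to correct for sampling from the proposal distribution
pred_true = pred_true - F.log(expected_count_true)
pred_sampled = F.broadcast_sub(pred_sampled,
                               F.log(expected_count_sampled.reshape((1, -1))))
# the true word is always column 0 of the (N, 1 + K) logits
pred = F.concat(pred_true.reshape((-1, 1)), pred_sampled, dim=1)
new_targets = F.zeros_like(pred_true)
```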
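For readers without MXNet at hand, here is a dependency-free sketch of the same computation for a single hidden vector. It is not GluonNLP's implementation: the function name, the dict of expected counts, and the toy sizes are assumptions for illustration. It also masks "accidental hits", where a sampled word equals the target.

```python
import math
import random

def sampled_softmax_loss(x, weight, bias, target, sampled, q):
    """Importance-sampled softmax cross-entropy for one hidden vector.

    x: hidden vector; weight/bias: full output layer (one row per word);
    target: true word id; sampled: K sampled word ids;
    q: maps word id -> expected count of that word under the proposal.
    Returns (loss, number of logits actually computed).
    """
    def logit(w):
        return sum(wi * xi for wi, xi in zip(weight[w], x)) + bias[w]

    # Column 0 is the true word; every logit is corrected by subtracting log q.
    corrected = [logit(target) - math.log(q[target])]
    for w in sampled:
        # Mask accidental hits so the target is not counted as a negative.
        corrected.append(float("-inf") if w == target else logit(w) - math.log(q[w]))

    # Stable log-sum-exp, then cross-entropy with the label fixed at column 0.
    m = max(corrected)
    log_z = m + math.log(sum(math.exp(c - m) for c in corrected))
    return log_z - corrected[0], len(corrected)


# Toy usage with hypothetical sizes and a uniform proposal distribution.
rng = random.Random(0)
vocab_size, hidden, k = 50, 8, 5
weight = [[rng.gauss(0, 0.1) for _ in range(hidden)] for _ in range(vocab_size)]
bias = [0.0] * vocab_size
x = [rng.gauss(0, 1) for _ in range(hidden)]
sampled = rng.sample(range(vocab_size), k)
q = {w: k / vocab_size for w in range(vocab_size)}
loss, width = sampled_softmax_loss(x, weight, bias, target=3, sampled=sampled, q=q)
print(f"loss={loss:.3f}, logits computed={width} instead of {vocab_size}")
```

With a uniform proposal the log-q correction is a constant shift, so it leaves the loss unchanged. It matters for skewed proposals such as the log-uniform sampler used below.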

#### Set environment

Please set `use_gpu` to False if no GPUs are available.

```python
use_gpu = True
context = mx.gpu() if use_gpu else mx.cpu()
log_interval = 50
```

#### Set hyperparameters for training

```python
batch_size = 20
lr = 20
epochs = 3
bptt = 35
grad_clip = 0.25
```

#### Load dataset, extract vocabulary, numericalize, and batchify for truncated BPTT

To keep the demonstration short, we train on the WikiText-2 validation set and validate on the test set.

```python
dataset_name = 'wikitext-2'
train_dataset = data.WikiText2(segment='val', bos=None, eos='<eos>')
val_dataset = data.WikiText2(segment='test', bos=None, eos='<eos>')

vocab = nlp.Vocab(data.Counter(train_dataset), padding_token=None, bos_token=None)

bptt_batchify = nlp.data.batchify.CorpusBPTTBatchify(
    vocab, bptt, batch_size, last_batch='discard')
train_data, val_data = [
    bptt_batchify(x) for x in [train_dataset, val_dataset]
]
print(vocab)
```

```
Vocab(size=13777, unk="<unk>", reserved="['<eos>']")
```
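`CorpusBPTTBatchify` turns the numericalized corpus into (data, target) pairs of shape `(bptt, batch_size)` so that hidden states carry over between consecutive batches. The pure-Python sketch below illustrates that idea on a toy stream; the helper name is hypothetical, and it is not GluonNLP's implementation, so edge cases such as handling of the final token may differ.

```python
def toy_bptt_batchify(token_ids, bptt, batch_size):
    """Split a token-id stream into (inputs, targets) of shape (bptt, batch_size).

    The stream is cut into batch_size contiguous columns, so position b of
    consecutive batches continues the same stretch of text. Targets are the
    inputs shifted one token ahead; an incomplete final window is discarded.
    """
    col_len = len(token_ids) // batch_size
    columns = [token_ids[b * col_len:(b + 1) * col_len] for b in range(batch_size)]
    batches = []
    # Each window needs bptt inputs plus one extra token for the last target.
    for start in range(0, col_len - bptt, bptt):
        inputs = [[columns[b][start + t] for b in range(batch_size)] for t in range(bptt)]
        targets = [[columns[b][start + t + 1] for b in range(batch_size)] for t in range(bptt)]
        batches.append((inputs, targets))
    return batches


batches = toy_bptt_batchify(list(range(24)), bptt=3, batch_size=2)
x_batch, y_batch = batches[0]
print(len(batches))
print(x_batch)
print(y_batch)
```

Each row of `x_batch` is one time step across the batch, which matches the TNC layout of the LSTM.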


#### Define parameters for model architecture

```python
vocab_size = len(vocab)
num_sampled = 8192
num_hidden = 200
num_embed = 200
num_layers = 2
```

#### Create candidate sampler

The `LogUniformSampler` class samples classes from an approximately log-uniform, or Zipfian, distribution <sup>[4]</sup>:

$$P(c) = \frac{\log(c + 2) - \log(c + 1)}{\log(R + 1)}$$

where $c$ is the class (word) index and $R$ (`range_max`) is the number of classes.

This sampler is useful when the true classes approximately follow such a distribution, as word tokens do when the vocabulary is indexed in order of decreasing frequency.
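The notebook imports `LogUniformSampler` from a local `sampler` module whose source is not shown. Below is a standalone sketch of the idea under stated assumptions: it draws distinct ids by inverse-CDF sampling from the formula above, and approximates each id's expected count as 1 - (1 - P(c))^num_tries, a common approximation for sampling without duplicates. All names are hypothetical, and the real sampler may differ.

```python
import math
import random

def log_uniform_prob(c, range_max):
    """P(c) for class ids 0..range_max-1 under the log-uniform (Zipfian) law."""
    return (math.log(c + 2) - math.log(c + 1)) / math.log(range_max + 1)

def sample_log_uniform(range_max, num_sampled, rng):
    """Draw num_sampled distinct class ids; also return the number of draws."""
    log_range = math.log(range_max + 1)
    seen, ids, num_tries = set(), [], 0
    while len(ids) < num_sampled:
        num_tries += 1
        # Inverse CDF: floor(exp(u * log(range_max + 1))) - 1 follows P(c).
        c = min(int(math.exp(rng.random() * log_range)) - 1, range_max - 1)
        if c not in seen:
            seen.add(c)
            ids.append(c)
    return ids, num_tries

def expected_count(c, range_max, num_tries):
    """Approximate chance that id c appears at least once in num_tries draws."""
    return 1.0 - (1.0 - log_uniform_prob(c, range_max)) ** num_tries


rng = random.Random(42)
vocab_size = 13777
total = sum(log_uniform_prob(c, vocab_size) for c in range(vocab_size))
ids, tries = sample_log_uniform(vocab_size, 100, rng)
print(f"sum of P(c) = {total:.6f}, {len(ids)} unique ids after {tries} draws")
```

The probabilities telescope, so they sum to 1, and frequent (low-index) words receive much larger expected counts. That is why the log-q correction in `ISDense` matters.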

Besides `LogUniformSampler`, Gluon NLP also provides other samplers such as `UnigramSampler`.

```python
candidate_sampler = LogUniformSampler(vocab_size, num_sampled)
```

#### Create model and loss

```python
net = RNNModel(vocab_size, num_embed, num_hidden, num_layers, num_sampled, dropout=0.5)
```

```python
net.initialize(mx.init.Xavier(), ctx=context)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
parameters = net.collect_params().values()
print(net)
```

```
RNNModel(
  (encoder): Embedding(13777 -> 200, float32)
  (rnn): LSTM(200 -> 200, TNC, num_layers=2, dropout=0.5)
  (sampled_decoder): ISDense(200 -> 13777, with 8192 samples)
  (decoder): Dense(200 -> 13777, linear)
  (drop): Dropout(p = 0.5, axes=())
)
```
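Even on this small vocabulary, the output layer is a large share of the model. The tally below uses the sizes from the printed model. It assumes MXNet's LSTM layout, with input-to-hidden and hidden-to-hidden weights and biases for four gates per layer. It counts the shared decoder weight and bias once, and it is a hand count, not output from Gluon.

```python
# Sizes taken from the printed model.
vocab_size, num_embed, num_hidden, num_layers = 13777, 200, 200, 2

embedding = vocab_size * num_embed
# Each LSTM layer: i2h and h2h weights for 4 gates plus two bias vectors.
lstm = 0
for layer in range(num_layers):
    in_size = num_embed if layer == 0 else num_hidden
    lstm += 4 * num_hidden * (in_size + num_hidden) + 2 * 4 * num_hidden
# ISDense and Dense share a single weight matrix and bias vector.
decoder = vocab_size * num_hidden + vocab_size

total = embedding + lstm + decoder
print(f"embedding={embedding:,} lstm={lstm:,} decoder={decoder:,} total={total:,}")
print(f"decoder share: {decoder / total:.1%}")
```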


#### Evaluation with full softmax 

```python
def evaluate(net, loss, data_source, batch_size, ctx):
    total_L = 0.0
    ntotal = 0
    hidden = net.begin_state(batch_size=batch_size, func=mx.nd.zeros, ctx=ctx)
    for i, (batch_data, batch_target) in enumerate(data_source):
        batch_data = batch_data.as_in_context(ctx)
        batch_target = batch_target.as_in_context(ctx)
        # evaluate full logits for testing
        sample_mode = False
        output, hidden = net(batch_data, hidden, sample_mode)
        # hidden = detach(hidden)
        L = loss(output, batch_target.reshape(-1))
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal
```
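`evaluate` returns the mean per-token cross-entropy (natural log), and the training loop reports `math.exp` of that value as perplexity. A tiny standalone sketch of the relationship (the helper name is hypothetical):

```python
import math

def perplexity(token_losses):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_losses) / len(token_losses))

# A model that gives every target word probability 1/400 has perplexity 400.
losses = [-math.log(1 / 400)] * 10
print(perplexity(losses))
```

Perplexity can be read as the effective number of equally likely choices the model is hedging between at each step.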

### Training

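Now that everything is ready, we can start training the model. Note that the training loss and perplexity printed during training are computed with the sampled softmax, so they are not directly comparable to the validation numbers, which use the full softmax.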

```python
TRAIN_BATCH_LOG = '[Epoch %d Batch %d/%d] training loss %.2f, ppl %.2f'
TRAIN_EPOCH_LOG = '[Epoch %d] Evaluation time cost %.2fs, valid loss %.2f, valid ppl %.2f'
TOTAL_LOG = 'Total training time cost %.2f s'
```

#### Training with importance sampled softmax

```python
def train(net, loss, train_data, val_data, epochs, lr, candidate_sampler):
    start_train_time = time.time()
    for epoch in range(epochs):
        total_L = 0.0
        start_epoch_time = time.time()
        hiddens = net.begin_state(batch_size=batch_size, func=mx.nd.zeros, ctx=context)
        for i, (data, target) in enumerate(train_data):
            hiddens = detach(hiddens)
            L = 0
            with autograd.record():
                X, y = data.as_in_context(context), target.as_in_context(context)
                sample_mode = True
                # pass sampled candidates for training
                sampled_values = candidate_sampler(y)
                output, hiddens, new_target = net(X, hiddens, sample_mode, sampled_values, y)
                batch_L = loss(output, new_target)
                L = L + batch_L / X.size
            L.backward()
            grads = [p.grad(context) for p in parameters]
            gluon.utils.clip_global_norm(grads, grad_clip)
            trainer.step(1)
            total_L += L.sum().asscalar()
            if i % log_interval == 0 and i > 0:
                cur_L = total_L / log_interval
                print(TRAIN_BATCH_LOG%(epoch, i, len(train_data), cur_L, math.exp(cur_L)))
                total_L = 0.0
        mx.nd.waitall()
        val_L = evaluate(net, loss, val_data, batch_size, context)
        print(TRAIN_EPOCH_LOG%(epoch, time.time()-start_epoch_time, val_L, math.exp(val_L)))
    print(TOTAL_LOG%(time.time() - start_train_time))
```
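Before each update, the training loop rescales all gradients jointly with `gluon.utils.clip_global_norm`, which helps keep the large learning rate of 20 from producing huge steps. As a rough pure-Python illustration of global-norm clipping, assuming gradients are plain lists of floats (the Gluon function works on NDArrays instead):

```python
import math

def clip_global_norm(grads, max_norm):
    """Scale gradient vectors so that their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + 1e-8)
    if scale < 1.0:
        grads = [[g * scale for g in grad] for grad in grads]
    return grads, total_norm


# Two toy "parameters" whose gradients have a joint norm of 5.
clipped, norm = clip_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=0.25)
print(norm, clipped)
```

Because one scale factor is applied to every gradient, clipping shrinks the step size without changing its direction.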

#### Train and evaluate

```python
train(net, loss, train_data, val_data, epochs, lr, candidate_sampler)
```

```
[Epoch 0 Batch 50/309] training loss 7.80, ppl 2434.99
[Epoch 0 Batch 100/309] training loss 6.94, ppl 1029.57
[Epoch 0 Batch 150/309] training loss 6.71, ppl 819.89
[Epoch 0 Batch 200/309] training loss 6.54, ppl 691.22
[Epoch 0 Batch 250/309] training loss 6.42, ppl 612.16
[Epoch 0 Batch 300/309] training loss 6.30, ppl 547.02
[Epoch 0] Evaluation time cost 28.31s, valid loss 5.99, valid ppl 400.87
[Epoch 1 Batch 50/309] training loss 6.36, ppl 577.76
[Epoch 1 Batch 100/309] training loss 6.21, ppl 498.41
[Epoch 1 Batch 150/309] training loss 6.08, ppl 436.84
[Epoch 1 Batch 200/309] training loss 6.05, ppl 422.54
[Epoch 1 Batch 250/309] training loss 5.96, ppl 389.14
[Epoch 1 Batch 300/309] training loss 5.89, ppl 362.33
[Epoch 1] Evaluation time cost 28.10s, valid loss 5.62, valid ppl 275.18
[Epoch 2 Batch 50/309] training loss 5.98, ppl 395.63
[Epoch 2 Batch 100/309] training loss 5.88, ppl 357.46
[Epoch 2 Batch 150/309] training loss 5.76, ppl 317.76
[Epoch 2 Batch 200/309] traini
```

## Practice

Change `num_sampled`. How does it affect training?

## Conclusion

In this tutorial, we learned how to build a language model that approximates the softmax with importance sampling during training and computes the full softmax at test time.

## Reference

[1] Chelba, Ciprian, et al. "One billion word benchmark for measuring progress in statistical language modeling." arXiv preprint arXiv:1312.3005 (2013).

[2] Bengio, Yoshua, and Jean-Sébastien Senécal. "Adaptive importance sampling to accelerate training of a neural probabilistic language model." IEEE Transactions on Neural Networks 19.4 (2008): 713-722.

[3] Jozefowicz, Rafal, et al. "Exploring the limits of language modeling." arXiv preprint arXiv:1602.02410 (2016).

[4] https://en.wikipedia.org/wiki/Zipf%27s_law