# Concise Implementation of Recurrent Neural Networks
:label:`sec_rnn-concise`

While :numref:`sec_rnn_scratch` was instructive for seeing how RNNs are implemented,
implementing them from scratch is neither convenient nor fast.
This section will show how to implement the same language model more efficiently
using functions provided by high-level APIs
of a deep learning framework.
We begin as before by reading the time machine dataset.


In [1]:
%use @file[../djl.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
//import jp.live.ugai.d2j.attention.Chap10Utils
import jp.live.ugai.d2j.timemachine.TimeMachine
import jp.live.ugai.d2j.timemachine.Vocab
import jp.live.ugai.d2j.timemachine.RNNModelScratch
import jp.live.ugai.d2j.SeqDataLoader
import jp.live.ugai.d2j.util.StopWatch
import jp.live.ugai.d2j.util.Accumulator
import jp.live.ugai.d2j.util.Training
import kotlin.random.Random
import kotlin.Pair
import kotlin.collections.List
// %load ../utils/djl-imports
// %load ../utils/plot-utils
// %load ../utils/PlotUtils.java

// %load ../utils/Accumulator.java
// %load ../utils/Animator.java
// %load ../utils/Functions.java
// %load ../utils/StopWatch.java
// %load ../utils/Training.java
// %load ../utils/timemachine/Vocab.java
// %load ../utils/timemachine/RNNModelScratch.java
// %load ../utils/timemachine/TimeMachine.java

In [2]:
import ai.djl.training.dataset.Record;

In [3]:
val manager = NDManager.newBaseManager();

## Creating a Dataset in DJL

In DJL, the ideal and concise way to deal with datasets is either to use the built-in datasets, which can easily wrap existing NDArrays, or to create your own dataset that extends the `RandomAccessDataset` class. For this section, we will implement our own. For more information on creating your own dataset in DJL, see https://djl.ai/docs/development/how_to_use_dataset.html

Our implementation of `TimeMachineDataset` will be a concise replacement for the `SeqDataLoader` class created previously. Using a dataset in DJL format allows us to rely on built-in functionality, so we do not have to implement most things from scratch. We have to implement a `Builder`, a `prepare` function that loads the data into the `TimeMachineDataset` object, and finally a `get` function.

In [4]:
class TimeMachineDataset(builder: Builder) : RandomAccessDataset(builder) {
    var vocab: Vocab? = null
    private var data: NDArray
    private var labels: NDArray
    private val numSteps: Int
    private val maxTokens: Int
    private val batchSize: Int
    private val manager: NDManager?
    private var prepared: Boolean

    init {
        numSteps = builder.numSteps
        maxTokens = builder.maxTokens
        batchSize = builder.sampler.batchSize
        manager = builder.manager
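        // Empty placeholders until `prepare` runs; the width follows `numSteps`
        data = manager!!.create(Shape(0, numSteps.toLong()), DataType.INT32)
        labels = manager.create(Shape(0, numSteps.toLong()), DataType.INT32)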
        prepared = false
    }

    override fun get(manager: NDManager, index: Long): Record {
        val X = data[NDIndex("{}", index)]
        val Y = labels[NDIndex("{}", index)]
        return Record(NDList(X), NDList(Y))
    }

    override fun availableSize(): Long {
        return data.shape[0]
    }

    override fun prepare(progress: Progress?) {
        if (prepared) {
            return
        }
        var corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens)
        val corpus: List<Int> = corpusVocabPair.first
        vocab = corpusVocabPair.second

        // Start with a random offset (inclusive of `numSteps - 1`) to partition a
        // sequence
        val offset: Int = Random.nextInt(numSteps)
        val numTokens = ((corpus.size - offset - 1) / batchSize) * batchSize
        var Xs = manager!!.create(corpus.subList(offset, offset + numTokens).toIntArray())
        var Ys = manager.create(corpus.subList(offset + 1, offset + 1 + numTokens).toIntArray())
        Xs = Xs.reshape(Shape(batchSize.toLong(), -1))
        Ys = Ys.reshape(Shape(batchSize.toLong(), -1))
        val numBatches = Xs.shape[1].toInt() / numSteps
        val xNDList = NDList()
        val yNDList = NDList()
        var i = 0
        while (i < numSteps * numBatches) {
            val X = Xs[NDIndex(":, {}:{}", i, i + numSteps)]
            val Y = Ys[NDIndex(":, {}:{}", i, i + numSteps)]
            xNDList.add(X)
            yNDList.add(Y)
            i += numSteps
        }
        data = NDArrays.concat(xNDList)
        xNDList.close()
        labels = NDArrays.concat(yNDList)
        yNDList.close()
        prepared = true
    }

    class Builder : BaseBuilder<Builder>() {
        var numSteps = 0
        var maxTokens = 0
        var manager: NDManager? = null
        override fun self(): Builder {
            return this
        }

        fun setSteps(steps: Int): Builder {
            numSteps = steps
            return this
        }

        fun setMaxTokens(maxTokens: Int): Builder {
            this.maxTokens = maxTokens
            return this
        }

        fun setManager(manager: NDManager): Builder {
            this.manager = manager
            return this
        }

        fun build(): TimeMachineDataset {
            return TimeMachineDataset(this)
        }
    }
}
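
The partitioning done in `prepare` can be illustrated on a toy corpus without DJL. The following is a minimal sketch, assuming plain Kotlin lists in place of NDArrays; `partitionSequential` is a hypothetical helper introduced only for illustration and is not part of DJL or the D2J utilities.

```kotlin
// Hypothetical helper for illustration; not part of DJL or the D2J utilities.
// Returns a list of (X, Y) minibatches, each of shape (batchSize, numSteps).
fun partitionSequential(
    corpus: List<Int>,
    batchSize: Int,
    numSteps: Int,
    offset: Int = 0
): List<Pair<List<List<Int>>, List<List<Int>>>> {
    // Keep only as many tokens as divide evenly into `batchSize` rows,
    // leaving one token of headroom for the labels, which are shifted by one.
    val numTokens = ((corpus.size - offset - 1) / batchSize) * batchSize
    require(numTokens > 0) { "corpus is too short for this batch size" }
    val rowLength = numTokens / batchSize
    val xs = corpus.subList(offset, offset + numTokens).chunked(rowLength)
    val ys = corpus.subList(offset + 1, offset + 1 + numTokens).chunked(rowLength)
    val numBatches = rowLength / numSteps
    return (0 until numBatches).map { b ->
        val from = b * numSteps
        val to = from + numSteps
        Pair(xs.map { it.subList(from, to) }, ys.map { it.subList(from, to) })
    }
}

val toyBatches = partitionSequential((0 until 35).toList(), batchSize = 2, numSteps = 5)
println(toyBatches[0].first)  // [[0, 1, 2, 3, 4], [17, 18, 19, 20, 21]]
println(toyBatches[0].second) // [[1, 2, 3, 4, 5], [18, 19, 20, 21, 22]]
```

Each label row is the corresponding input row shifted by one token, and consecutive minibatches continue where the previous one stopped within each row.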

Consequently, we update the functions `predictCh8`, `trainCh8`, `trainEpochCh8`, and `gradClipping` from the previous section to include the dataset logic and to accept an `AbstractBlock` from DJL in addition to `RNNModelScratch`.

In [5]:
/** Generate new characters following the `prefix`. */
fun predictCh8(
    prefix: String,
    numPreds: Int,
    net: Any,
    vocab: Vocab,
    device: Device,
    manager: NDManager
): String {
    val outputs: MutableList<Int> = ArrayList()
    outputs.add(vocab.getIdx("" + prefix[0]))
    val getInput = {
        manager.create(outputs[outputs.size - 1])
            .toDevice(device, false)
            .reshape(Shape(1, 1))
    }
    if (net is RNNModelScratch) {
        val castedNet = net
        var state: NDList = castedNet.beginState(1, device)
        for (c in prefix.substring(1).toCharArray()) { // Warm-up period
            state = castedNet.forward(getInput(), state).second
            outputs.add(vocab.getIdx("" + c))
        }
        var y: NDArray
        for (i in 0 until numPreds) {
            val pair = castedNet.forward(getInput(), state)
            y = pair.first
            state = pair.second
            outputs.add(y.argMax(1).reshape(Shape(1)).getLong(0L).toInt())
        }
    } else {
        val castedNet = net as AbstractBlock
        var state: NDList? = null
        for (c in prefix.substring(1).toCharArray()) { // Warm-up period
            state = if (state == null) {
                // Begin state
                castedNet
                    .forward(
                        ParameterStore(manager, false),
                        NDList(getInput()),
                        false
                    )
                    .subNDList(1)
            } else {
                castedNet
                    .forward(
                        ParameterStore(manager, false),
                        NDList(getInput()).addAll(state),
                        false
                    )
                    .subNDList(1)
            }
            outputs.add(vocab.getIdx("" + c))
        }
        var y: NDArray
        for (i in 0 until numPreds) {
            val pair = castedNet.forward(
                ParameterStore(manager, false),
                NDList(getInput()).addAll(state),
                false
            )
            y = pair[0]
            state = pair.subNDList(1)
            outputs.add(y.argMax(1).reshape(Shape(1)).getLong(0L).toInt())
        }
    }
    val output = StringBuilder()
    for (i in outputs) {
        output.append(vocab.idxToToken[i])
    }
    return output.toString()
}
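
The control flow of `predictCh8` (a warm-up pass over the prefix followed by greedy argmax decoding) can be sketched independently of DJL. This is a minimal sketch under stated assumptions: `greedyDecode` and the toy step function are hypothetical, and the "model" here is just a function that maps a token and a state to scores and a new state.

```kotlin
// Hypothetical helper for illustration; `step` stands in for one forward pass.
fun <S> greedyDecode(
    prefix: List<Int>,
    numPreds: Int,
    initState: S,
    step: (Int, S) -> Pair<DoubleArray, S>
): List<Int> {
    require(prefix.isNotEmpty()) { "prefix must contain at least one token" }
    val outputs = mutableListOf(prefix[0])
    var state = initState
    // Warm-up: feed the prefix to update the state, ignoring the predictions
    for (token in prefix.drop(1)) {
        state = step(outputs.last(), state).second
        outputs.add(token)
    }
    // Prediction: feed the last output back in and keep the highest-scoring token
    repeat(numPreds) {
        val result = step(outputs.last(), state)
        state = result.second
        val scores = result.first
        outputs.add(scores.indices.maxByOrNull { scores[it] }!!)
    }
    return outputs
}

// Toy "model": always scores (token + 1) mod vocabulary size highest;
// the state merely counts how many steps have been taken.
val toyVocabSize = 5
val toyStep = { token: Int, calls: Int ->
    val scores = DoubleArray(toyVocabSize)
    scores[(token + 1) % toyVocabSize] = 1.0
    Pair(scores, calls + 1)
}
println(greedyDecode(listOf(0, 1), 3, 0, toyStep)) // [0, 1, 2, 3, 4]
```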

In [16]:
/** Clip the gradient. */
fun gradClipping(net: Any, theta: Int, manager: NDManager) {
    var result = 0.0
    val params: NDList
    if (net is RNNModelScratch) {
        params = net.params
    } else {
        params = NDList()
        for (pair in (net as AbstractBlock).parameters) {
            params.add(pair.value.array)
        }
    }
    for (p in params) {
        val gradient = p.gradient.stopGradient()
        gradient.attach(manager)
        result += gradient.pow(2).sum().getFloat()
    }
    val norm = Math.sqrt(result)
    if (norm > theta) {
        for (param in params) {
            val gradient = param.gradient
            gradient.muli(theta / norm)
        }
    }
}
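
The rule applied by `gradClipping` is to rescale all gradients by $\min(1, \theta / \|\mathbf{g}\|)$, where $\|\mathbf{g}\|$ is the norm over all parameters taken together. The following minimal sketch shows the same rule on plain arrays; `clipByGlobalNorm` is a hypothetical helper and uses `DoubleArray` in place of NDArrays.

```kotlin
import kotlin.math.sqrt

// Hypothetical helper for illustration; gradients are plain arrays, not NDArrays.
// Rescales the gradients in place and returns the norm measured before clipping.
fun clipByGlobalNorm(grads: List<DoubleArray>, theta: Double): Double {
    val norm = sqrt(grads.sumOf { g -> g.sumOf { it * it } })
    if (norm > theta) {
        val scale = theta / norm
        for (g in grads) {
            for (i in g.indices) g[i] *= scale
        }
    }
    return norm
}

val toyGrads = listOf(doubleArrayOf(3.0), doubleArrayOf(4.0))
val normBefore = clipByGlobalNorm(toyGrads, theta = 1.0)
println(normBefore) // 5.0
println(toyGrads.map { it.toList() }) // roughly [[0.6], [0.8]]
```

After clipping, the global norm equals `theta` (up to rounding), while the direction of the overall gradient is unchanged.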

In [17]:
/** Train a model within one epoch. */
fun trainEpochCh8(
    net: Any,
    dataset: RandomAccessDataset,
    loss: Loss,
    updater: (Int, NDManager) -> Unit,
    device: Device,
    useRandomIter: Boolean,
    manager: NDManager
): Pair<Double, Double> {
    val watch = StopWatch()
    watch.start()
    val metric = Accumulator(2) // Sum of training loss, no. of tokens
    manager.newSubManager().use { childManager ->
        var state: NDList? = null
        for (batch in dataset.getData(childManager)) {
            var X = batch.data.head().toDevice(device, true)
            val Y = batch.labels.head().toDevice(device, true)
            if (state == null || useRandomIter) {
                // Initialize `state` when either it is the first iteration or
                // using random sampling
                if (net is RNNModelScratch) {
                    state = net.beginState(X.shape.shape[0].toInt(), device)
                }
            } else {
                for (s in state) {
                    s.stopGradient()
                }
            }
            state?.attach(childManager)
            var y = Y.transpose().reshape(Shape(-1))
            X = X.toDevice(device, false)
            y = y.toDevice(device, false)
            Engine.getInstance().newGradientCollector().use { gc ->
                val yHat: NDArray
                if (net is RNNModelScratch) {
                    val pairResult = net.forward(X, state!!)
                    yHat = pairResult.first
                    state = pairResult.second
                } else {
                    val pairResult: NDList
                    pairResult = if (state == null) {
                        // Begin state
                        (net as AbstractBlock)
                            .forward(
                                ParameterStore(manager, false),
                                NDList(X),
                                true
                            )
                    } else {
                        (net as AbstractBlock)
                            .forward(
                                ParameterStore(manager, false),
                                NDList(X).addAll(state),
                                true
                            )
                    }
                    yHat = pairResult[0]
                    state = pairResult.subNDList(1)
                }
                val l = loss.evaluate(NDList(y), NDList(yHat)).mean()
                gc.backward(l)
                metric.add(floatArrayOf(l.getFloat() * y.size(), y.size().toFloat()))
            }
            gradClipping(net, 1, childManager)
            updater(1, childManager) // Since the `mean` function has been invoked
        }
    }
    return Pair(Math.exp((metric.get(0) / metric.get(1)).toDouble()), metric.get(1) / watch.stop())
}
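
The first value returned by `trainEpochCh8` is the perplexity: the exponential of the summed cross-entropy loss divided by the number of tokens. As a minimal sketch, assuming we already know the probability the model assigned to each correct next token, it can be computed as follows; `perplexity` is a hypothetical helper.

```kotlin
import kotlin.math.exp
import kotlin.math.ln

// Hypothetical helper for illustration.
// `tokenProbs[i]` is the probability assigned to the i-th correct next token.
fun perplexity(tokenProbs: DoubleArray): Double {
    val totalLoss = tokenProbs.sumOf { -ln(it) } // summed cross-entropy loss
    return exp(totalLoss / tokenProbs.size)      // exp of the mean loss per token
}

// A model that always spreads its probability uniformly over 4 tokens
// has a perplexity of approximately 4.
println(perplexity(doubleArrayOf(0.25, 0.25, 0.25, 0.25)))
```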


In [18]:
/** Train a model. */
fun trainCh8(
    net: Any,
    dataset: RandomAccessDataset,
    vocab: Vocab,
    lr: Float,
    numEpochs: Int,
    device: Device,
    useRandomIter: Boolean,
    manager: NDManager?
) {
    val loss = SoftmaxCrossEntropyLoss()
//    val animator = Animator()
    val updater: (Int, NDManager) -> Unit = if (net is RNNModelScratch) {
        { batchSize: Int, subManager: NDManager ->
            Training.sgd(net.params, lr.toFloat(), batchSize, subManager)
        }
    } else {
        { batchSize: Int, subManager: NDManager ->
            // Already initialized net
            val castedNet = net as AbstractBlock
            val model: Model = Model.newInstance("model")
            model.block = castedNet
            val lrt: Tracker = Tracker.fixed(lr)
            val sgd: Optimizer = Optimizer.sgd().setLearningRateTracker(lrt).build()
            val config: DefaultTrainingConfig = DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer (SGD with a fixed learning rate)
                .optInitializer(
                    NormalInitializer(0.01f),
                    Parameter.Type.WEIGHT
                ) // setting the initializer
                .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed
                .addEvaluator(Accuracy()) // Model Accuracy
                .addTrainingListeners(*TrainingListener.Defaults.logging()) // Logging
            val trainer: Trainer = model.newTrainer(config)
            trainer.step()
        }
    }
    val predict: (String) -> String =
        { prefix ->
            predictCh8(prefix, 50, net, vocab, device, manager!!)
        }
    // Train and predict
    var ppl = 0.0
    var speed = 0.0
    for (epoch in 0 until numEpochs) {
        val pair = trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager!!)
        ppl = pair.first
        speed = pair.second
        if ((epoch + 1) % 10 == 0) {
//            animator.add(epoch + 1, ppl.toFloat(), "")
//            animator.show()
            println("${epoch + 1} : $ppl")
        }
    }
    println(
        "perplexity: %.1f, %.1f tokens/sec on %s%n".format(ppl, speed, device.toString())
    )
    println(predict("time traveller"))
    println(predict("traveller"))
}

Now we will leverage the dataset that we just created and assign the required parameters.

In [19]:
val batchSize = 32
val numSteps = 35

val dataset: TimeMachineDataset = TimeMachineDataset.Builder()
    .setManager(manager)
    .setMaxTokens(10000)
    .setSampling(batchSize, false)
    .setSteps(numSteps)
    .build()
dataset.prepare()
val vocab = dataset.vocab

## Defining the Model

High-level APIs provide implementations of recurrent neural networks.
We construct the recurrent neural network layer `rnnLayer` with a single hidden layer and 256 hidden units.
In fact, we have not even discussed yet what it means to have multiple layers---this will happen in :numref:`sec_deep_rnn`.
For now, suffice it to say that multiple layers simply amount to the output of one layer of RNN being used as the input for the next layer of RNN.


In [20]:
    val numHiddens = 256
    val rnnLayer = RNN.builder()
        .setNumLayers(1)
        .setStateSize(numHiddens)
        .optReturnState(true)
        .optBatchFirst(false)
        .build()

Initializing the hidden state is straightforward.
We define a helper function `beginState` _(in DJL we do not have to create the initial state ourselves and pass it to the first `forward` call, because DJL creates it the first time `forward` is run, but we create it here for demonstration purposes)_.
This returns a list (`state`)
that contains
an initial hidden state
for each example in the minibatch,
whose shape is
(number of hidden layers, batch size, number of hidden units).
For some models 
to be introduced later 
(e.g., long short-term memory),
such a list also
contains other information.

In [21]:
fun beginState(batchSize: Int, numLayers: Int, numHiddens: Int): NDList {
    return NDList(manager.zeros(Shape(numLayers.toLong(), batchSize.toLong(), numHiddens.toLong())))
}

val state = beginState(batchSize, 1, numHiddens)
println(state.size)
println(state[0].shape)

1
(1, 32, 256)


With a hidden state and an input,
we can compute the output with
the updated hidden state.
It should be emphasized that
the "output" (`Y`) of `rnnLayer`
does *not* involve computation of output layers:
it refers to 
the hidden state at *each* time step,
and they can be used as the input
to the subsequent output layer.

In addition,
the updated hidden state (`stateNew`) returned by `rnnLayer`
refers to the hidden state
at the *last* time step of the minibatch.
It can be used to initialize the 
hidden state for the next minibatch within an epoch
in sequential partitioning.
For multiple hidden layers,
the hidden state of each layer will be stored
in this variable (`stateNew`).
For some models 
to be introduced later 
(e.g., long short-term memory),
this variable also
contains other information.

In [22]:
val X = manager.randomUniform(0.0f, 1.0f, Shape(numSteps.toLong(), batchSize.toLong(), vocab!!.length().toLong()))

val input = NDList(X, state[0])
rnnLayer.initialize(manager, DataType.FLOAT32, *input.shapes)
val forwardOutput = rnnLayer.forward(ParameterStore(manager, false), input, false)
val Y = forwardOutput[0]
val stateNew = forwardOutput[1]

println(Y.shape)
println(stateNew.shape)

(35, 32, 256)
(1, 32, 256)
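
The relationship between the per-step output and the final state can be illustrated with a minimal sketch of a single-layer tanh RNN in plain Kotlin, processing one sequence without a batch dimension. `TinyRnn`, its weight names, and its random initialization are assumptions made for illustration; this is not how DJL implements `RNN`.

```kotlin
import kotlin.math.tanh
import kotlin.random.Random

// Hypothetical single-layer RNN for one sequence; not DJL's implementation.
class TinyRnn(private val inputSize: Int, private val hiddenSize: Int, seed: Int = 0) {
    private val rng = Random(seed)
    private val wxh = Array(inputSize) { DoubleArray(hiddenSize) { rng.nextDouble(-0.1, 0.1) } }
    private val whh = Array(hiddenSize) { DoubleArray(hiddenSize) { rng.nextDouble(-0.1, 0.1) } }
    private val bh = DoubleArray(hiddenSize)

    // One time step: h_t = tanh(x_t W_xh + h_{t-1} W_hh + b_h)
    private fun step(x: DoubleArray, h: DoubleArray): DoubleArray =
        DoubleArray(hiddenSize) { j ->
            var s = bh[j]
            for (i in 0 until inputSize) s += x[i] * wxh[i][j]
            for (k in 0 until hiddenSize) s += h[k] * whh[k][j]
            tanh(s)
        }

    // Returns the hidden state at every time step and the final hidden state.
    fun forward(xs: List<DoubleArray>, h0: DoubleArray): Pair<List<DoubleArray>, DoubleArray> {
        val outputs = ArrayList<DoubleArray>()
        var h = h0
        for (x in xs) {
            h = step(x, h)
            outputs.add(h)
        }
        return Pair(outputs, h)
    }
}

val tinyRnn = TinyRnn(inputSize = 4, hiddenSize = 3)
// Five one-hot input vectors of size 4
val tinyInputs = List(5) { t -> DoubleArray(4) { i -> if (i == t % 4) 1.0 else 0.0 } }
val tinyResult = tinyRnn.forward(tinyInputs, DoubleArray(3))
println(tinyResult.first.size)                                     // 5
println(tinyResult.first.last().contentEquals(tinyResult.second)) // true
```

The list of per-step hidden states plays the role of `Y`, and the final state plays the role of `stateNew`: it is simply the last entry of that list, and no output layer is involved.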


Similar to :numref:`sec_rnn_scratch`,
we define an `RNNModel` class 
for a complete RNN model.
Note that `rnnLayer` only contains the hidden recurrent layers, so we need to create a separate output layer.

In [23]:
class RNNModel(private val rnnLayer: RecurrentBlock, vocabSize: Int) : AbstractBlock() {
    private val dense: Linear
    private val vocabSize: Int

    init {
        this.addChildBlock("rnn", rnnLayer)
        this.vocabSize = vocabSize
        dense = Linear.builder().setUnits(vocabSize.toLong()).build()
        this.addChildBlock("linear", dense)
    }

    override fun forwardInternal(
        parameterStore: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        val X = inputs[0].transpose().oneHot(vocabSize)
        inputs[0] = X
//        println(inputs)
        val result = rnnLayer.forward(parameterStore, inputs, training)
        val Y = result[0]
        val state = result[1]
        val shapeLength = Y.shape.dimension()
        val output = dense.forward(
            parameterStore,
            NDList(Y.reshape(Shape(-1, Y.shape[shapeLength - 1]))),
            training
        )
        return NDList(output[0], state)
    }

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        val shape: Shape = rnnLayer.getOutputShapes(arrayOf(inputShapes[0]))[0]
        dense.initialize(manager, dataType, Shape(vocabSize.toLong(), shape.get(shape.dimension() - 1)))
    }

    /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock  */
    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape?> {
        return arrayOfNulls<Shape>(0)
    }
}
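
In `forwardInternal`, the token indices are transposed to time-major order before one-hot encoding, and the RNN output is flattened so that its rows are ordered by time step first and by example second. This is why `trainEpochCh8` flattens the labels with `Y.transpose().reshape(Shape(-1))`. A minimal sketch of both operations on plain lists follows; `oneHot` and `flattenTimeMajor` are hypothetical helpers used only for illustration.

```kotlin
// Hypothetical helpers for illustration, using plain lists instead of NDArrays.

// One-hot encode a single token index
fun oneHot(index: Int, vocabSize: Int): IntArray =
    IntArray(vocabSize).also { it[index] = 1 }

// Transpose a (batch, steps) matrix and flatten it in time-major order
fun flattenTimeMajor(batch: List<List<Int>>): List<Int> {
    val numSteps = batch[0].size
    return (0 until numSteps).flatMap { t -> batch.map { row -> row[t] } }
}

val toyLabels = listOf(listOf(1, 2, 3), listOf(4, 5, 6)) // (batch = 2, steps = 3)
println(flattenTimeMajor(toyLabels)) // [1, 4, 2, 5, 3, 6]
println(oneHot(2, 4).toList())       // [0, 0, 1, 0]
```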


## Training and Predicting

Before training the model, let us make a prediction with a model that has random weights.


In [24]:
val device = manager.device
val net = RNNModel(rnnLayer, vocab.length())
net.initialize(manager, DataType.FLOAT32, X.shape)
println(predictCh8("time traveller", 10, net, vocab, device, manager))


time travellerjmmjjjjjjj


As is quite obvious, this model does not work at all. Next, we call `trainCh8` with the same hyperparameters defined in :numref:`sec_rnn_scratch` and train our model with high-level APIs.


In [25]:
val numEpochs: Int = Integer.getInteger("MAX_EPOCH", 500)
val lr = 1.0f
trainCh8(net as Any, dataset, vocab, lr, numEpochs, device, false, manager)
predictCh8("time traveller", 10, net, vocab, device, manager)

10 : 8.245257727060578
20 : 6.0875992419532245
30 : 4.595497889528389
40 : 3.822660152024894
50 : 3.2301867387474688
60 : 2.80533039619512
70 : 2.4822068849381775
80 : 2.287052411897664
90 : 2.061359146862815
100 : 1.9161627962336654
110 : 1.832307384287594
120 : 1.745760329774405
130 : 1.6944842883948232
140 : 1.6028999438217844
150 : 1.5537197947336647
160 : 1.5361287295435533
170 : 1.4799384901306276
180 : 1.4519396279346952
190 : 1.425256105226484
200 : 1.385519922477976
210 : 1.3807876103069892
220 : 1.3635620176237664
230 : 1.3409062478980134
240 : 1.308174846239329
250 : 1.3085381352780492
260 : 1.297337448067251
270 : 1.3003313775132774
280 : 1.2786158917626498
290 : 1.2667826697630045
300 : 1.240752235497341
310 : 1.2783358644199587
320 : 1.258791214111969
330 : 1.2429950474258593
340 : 1.2417753937279017
350 : 1.2404491141516438
360 : 1.221693000372478
370 : 1.2273231406785485
380 : 1.2258791047534625
390 : 1.2214231827293505
400 : 1.2357831085253768
410 : 1.2150065401582133


time traveller came what

Compared with the last section, this model achieves comparable perplexity
in a shorter period of time, because the high-level APIs
of the deep learning framework are more heavily optimized.


## Summary

* High-level APIs of the deep learning framework provide an implementation of the RNN layer.
* The RNN layer of high-level APIs returns an output and an updated hidden state, where the output does not involve output layer computation.
* Using high-level APIs leads to faster RNN training than using its implementation from scratch.

## Exercises

1. Can you make the RNN model overfit using the high-level APIs?
1. What happens if you increase the number of hidden layers in the RNN model? Can you make the model work?
1. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.
