diff --git a/demos/lstm/README.md b/demos/lstm/README.md
new file mode 100644
index 0000000000..0bba1fc6ed
--- /dev/null
+++ b/demos/lstm/README.md
@@ -0,0 +1,25 @@
+# Learning digits of pi using an LSTM
+
+Demonstrates training a simple autoregressive LSTM network in TensorFlow and
+then porting that model to deeplearn.js.
+
+This network uses two ``BasicLSTMCell``s combined with ``MultiRNNCell``. The
+network is trained to memorize the first few digits of pi.
+
+First, train the LSTM network with TensorFlow:
+
+```
+python demos/lstm/train.py
+```
+
+Next, export the weights to be used by deeplearn.js:
+
+```
+python scripts/dump_checkpoint_vars.py --output_dir=demos/lstm/ --checkpoint_file=/tmp/simple_lstm-1000 --remove_variables_regex=".*Adam.*|.*beta.*"
+```
+
+Finally, start the demo:
+
+```
+scripts/watch-demo demos/lstm/lstm_inference.ts
+```
diff --git a/demos/lstm/fully_connected_biases b/demos/lstm/fully_connected_biases
new file mode 100644
index 0000000000..9f8fb8cdf4
--- /dev/null
+++ b/demos/lstm/fully_connected_biases
@@ -0,0 +1 @@
+,:MP>ƳmZKfN<>ֻGd6H
\ No newline at end of file
diff --git a/demos/lstm/fully_connected_weights b/demos/lstm/fully_connected_weights
new file mode 100644
index 0000000000..8061f0c24f
Binary files /dev/null and b/demos/lstm/fully_connected_weights differ
diff --git a/demos/lstm/index.html b/demos/lstm/index.html
new file mode 100644
index 0000000000..82bb347816
--- /dev/null
+++ b/demos/lstm/index.html
@@ -0,0 +1,25 @@
+
+
+
+ LSTM Demo
+
+
+ LSTM demo
+ Expected:
+
+ Results:
+
+
+
+
+
diff --git a/demos/lstm/lstm_inference.ts b/demos/lstm/lstm_inference.ts
new file mode 100644
index 0000000000..7e38e127d0
--- /dev/null
+++ b/demos/lstm/lstm_inference.ts
@@ -0,0 +1,79 @@
+/* Copyright 2017 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {Array1D, Array2D, CheckpointLoader, NDArrayMathGPU, Scalar,
+ util} from '../deeplearnjs';
+
+// Demo entry point: loads the exported LSTM weights, primes the network with
+// the digit 3 and repeatedly feeds the predicted digit back in as the next
+// input, then compares the generated digits with the expected digits of pi.
+//
+// manifest.json lives in the same directory.
+const reader = new CheckpointLoader('.');
+reader.getAllVariables().then(vars => {
+ // First digit of the sequence; used as the initial network input.
+ const primerData = 3;
+ // Digits the network is expected to produce after the primer.
+ const expected = [1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4];
+ const math = new NDArrayMathGPU();
+
+ // Variables are looked up by the names listed as keys in manifest.json.
+ const lstmKernel1 = vars[
+ 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'] as Array2D;
+ const lstmBias1 = vars[
+ 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'] as Array1D;
+
+ const lstmKernel2 = vars[
+ 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'] as Array2D;
+ const lstmBias2 = vars[
+ 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'] as Array1D;
+
+ const fullyConnectedBiases = vars['fully_connected/biases'] as Array1D;
+ const fullyConnectedWeights = vars['fully_connected/weights'] as Array2D;
+
+ // Predicted digits, in generation order.
+ const results: number[] = [];
+
+ math.scope((keep, track) => {
+ const forgetBias = track(Scalar.new(1.0));
+ // Pre-bind the weights so each cell has the (data, c, h) signature that
+ // multiRNNCell expects.
+ const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1,
+ lstmBias1);
+ const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2,
+ lstmBias2);
+
+ // Zero initial cell states (c) and outputs (h), one entry per cell. Each is
+ // a quarter of the bias length wide, since the bias covers the four gates.
+ let c = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
+ track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];
+ let h = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
+ track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];
+
+ let input = primerData;
+ for (let i = 0; i < expected.length; i++) {
+ // One-hot encode the current digit over the 10 possible digits.
+ const onehot = track(Array2D.zeros([1, 10]));
+ onehot.set(1.0, 0, input);
+
+ // multiRNNCell returns [nextCellStates, cellOutputs].
+ const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h);
+
+ c = output[0];
+ h = output[1];
+
+ // Project the output of the last cell in the stack to per-digit logits.
+ const outputH = h[1];
+ const weightedResult = math.matMul(outputH, fullyConnectedWeights);
+ const logits = math.add( weightedResult, fullyConnectedBiases);
+
+ // Greedy decoding: the highest-scoring digit becomes the next input.
+ const result = math.argMax(logits).get();
+ results.push(result);
+ input = result;
+ }
+ });
+ // Report the outcome in the page elements with ids 'expected', 'results' and
+ // 'success'.
+ document.getElementById('expected').innerHTML = '' + expected;
+ document.getElementById('results').innerHTML = '' + results;
+ if(util.arraysEqual(expected, results)) {
+ document.getElementById('success').innerHTML = 'Success!';
+ } else {
+ document.getElementById('success').innerHTML = 'Failure.';
+ }
+});
diff --git a/demos/lstm/manifest.json b/demos/lstm/manifest.json
new file mode 100644
index 0000000000..13a1f25732
--- /dev/null
+++ b/demos/lstm/manifest.json
@@ -0,0 +1,41 @@
+{
+ "fully_connected/biases": {
+ "filename": "fully_connected_biases",
+ "shape": [
+ 10
+ ]
+ },
+ "fully_connected/weights": {
+ "filename": "fully_connected_weights",
+ "shape": [
+ 20,
+ 10
+ ]
+ },
+ "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias": {
+ "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias",
+ "shape": [
+ 80
+ ]
+ },
+ "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel": {
+ "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel",
+ "shape": [
+ 30,
+ 80
+ ]
+ },
+ "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias": {
+ "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias",
+ "shape": [
+ 80
+ ]
+ },
+ "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel": {
+ "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel",
+ "shape": [
+ 40,
+ 80
+ ]
+ }
+}
\ No newline at end of file
diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias
new file mode 100644
index 0000000000..be3bf0d5fe
Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias differ
diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel
new file mode 100644
index 0000000000..756009d908
Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel differ
diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias
new file mode 100644
index 0000000000..4ffe9fe5b7
Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias differ
diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel
new file mode 100644
index 0000000000..58005430a6
Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel differ
diff --git a/demos/lstm/train.py b/demos/lstm/train.py
new file mode 100644
index 0000000000..bd89df8d8d
--- /dev/null
+++ b/demos/lstm/train.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trains and Evaluates a simple LSTM network."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('output_dir', '/tmp/simple_lstm',
+ 'Directory to write checkpoint.')
+
+def main(unused_argv):
+ """Trains a two-cell LSTM to memorize the leading digits of pi.
+
+ Builds the graph, trains on the single digit sequence, prints the predictions
+ next to the expected digits and saves a checkpoint using FLAGS.output_dir.
+
+ Args:
+ unused_argv: Arguments supplied by tf.app.run; not used.
+ """
+ # The first 20 digits of pi as a single sequence (a batch of one).
+ data = np.array(
+ [[3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4]])
+
+ tf.reset_default_graph()
+
+ # x is fed the sequence without its last digit and y the sequence shifted by
+ # one, so the target at every position is the next digit.
+ x = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1])
+ y = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1])
+
+ NHIDDEN = 20
+ NLABELS = 10
+
+ # Two stacked BasicLSTMCells with NHIDDEN units each.
+ lstm1 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN)
+ lstm2 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN)
+ lstm = tf.contrib.rnn.MultiRNNCell([lstm1, lstm2])
+ initial_state = lstm.zero_state(1, tf.float32)
+
+ # Digits are one-hot encoded over NLABELS classes before entering the LSTM.
+ outputs, final_state = tf.nn.dynamic_rnn(
+ cell=lstm, inputs=tf.one_hot(x, NLABELS), initial_state=initial_state)
+
+ # Linear projection of the LSTM outputs to NLABELS logits.
+ logits = tf.contrib.layers.linear(outputs, NLABELS)
+
+ softmax_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=y, logits=logits)
+
+ predictions = tf.argmax(logits, axis=-1)
+ loss = tf.reduce_mean(softmax_cross_entropy)
+
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+
+ sess = tf.InteractiveSession()
+ sess.run(tf.global_variables_initializer())
+
+ print('Starting training...')
+ NEPOCH = 1000
+ # Runs NEPOCH + 1 training steps (0..NEPOCH), logging the loss every 100.
+ for step in range(NEPOCH + 1):
+ loss_out, _ = sess.run([loss, train_op],
+ feed_dict={
+ x: data[:, :-1],
+ y: data[:, 1:],
+ })
+ if step % 100 == 0:
+ print('Loss at step {}: {}'.format(step, loss_out))
+
+ print('Expected data:')
+ print(data[:, 1:])
+ print('Results:')
+ print(sess.run([predictions], feed_dict={x: data[:, :-1]}))
+
+ # `step` still holds its final loop value (NEPOCH) here and is passed as the
+ # checkpoint's global_step.
+ saver = tf.train.Saver()
+ path = saver.save(sess, FLAGS.output_dir, global_step=step)
+ print('Saved checkpoint at {}'.format(path))
+
+
+if __name__ == '__main__':
+ tf.app.run(main)
+
diff --git a/src/math/math.ts b/src/math/math.ts
index 8d693a570e..13a02935e7 100644
--- a/src/math/math.ts
+++ b/src/math/math.ts
@@ -21,6 +21,11 @@ import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray';
export type ScopeResult = NDArray[]|NDArray|void;
+/**
+ * Signature of a single LSTM cell step, as consumed by multiRNNCell.
+ * Takes the cell input `data`, the previous cell state `c` and the previous
+ * cell output `h`, and returns the tuple [nextCellState, cellOutput].
+ */
+export interface LSTMCell {
+ (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D];
+}
+
+
export abstract class NDArrayMath {
private ndarrayScopes: NDArray[][] = [];
private activeScope: NDArray[];
@@ -1127,6 +1132,100 @@ export abstract class NDArrayMath {
x: Array3D, mean: Array3D|Array1D, variance: Array3D|Array1D,
varianceEpsilon: number, scale?: Array3D|Array1D,
offset?: Array3D|Array1D): Array3D;
+
+ //////////////
+ // LSTM ops //
+ //////////////
+
+ /**
+ * Computes the next states and outputs of a stack of LSTMCells.
+ * Each cell output is used as input to the next cell.
+ * This is only the forward mode.
+ * Derived from tf.contrib.rnn.MultiRNNCell.
+ * Only a batch size of 1 is supported: the assertion fails unless
+ * data.shape[0] === 1.
+ * @param lstmCells Array of LSTMCell functions.
+ * @param data The input to the cell.
+ * @param c Array of previous cell states.
+ * @param h Array of previous cell outputs.
+ * @return Tuple [nextCellStates, cellOutputs]
+ */
+ multiRNNCell(lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
+ h: Array2D[]): [Array2D[], Array2D[]] {
+ util.assert(
+ data.shape[0] === 1,
+ `Error in multiRNNCell: first dimension of data is ${data.shape[0]}, ` +
+ `but batch sizes > 1 are not yet supported.`);
+ const res = this.scope(() => {
+ let input = data;
+ // Flat list laid out as [c0, h0, c1, h1, ...] so that the scope returns a
+ // single array.
+ const newStates = [];
+ for (let i = 0; i < lstmCells.length; i++) {
+ const output = lstmCells[i](input, c[i], h[i]);
+ newStates.push(output[0]);
+ newStates.push(output[1]);
+ // The output of this cell is the input of the next one.
+ input = output[1];
+ }
+
+ return newStates;
+ });
+ // Split the interleaved result back into per-cell states and outputs.
+ const newC: Array2D[] = [];
+ const newH: Array2D[] = [];
+ for (let i = 0; i < res.length; i += 2) {
+ newC.push(res[i] as Array2D);
+ newH.push(res[i + 1] as Array2D);
+ }
+ return [newC, newH];
+ }
+
+  /**
+   * Computes the next state and output of a BasicLSTMCell.
+   * This is only the forward mode.
+   * Derived from tf.contrib.rnn.BasicLSTMCell.
+   * Only a batch size of 1 is supported: the assertion fails unless
+   * data.shape[0] === 1.
+   * @param forgetBias Forget bias for the cell; added to the forget gate
+   *     before the sigmoid.
+   * @param lstmKernel The weights for the cell. It is multiplied by the
+   *     concatenation of data and h, and its output columns are split into
+   *     four equal gate slices.
+   * @param lstmBias The biases for the cell, added to the kernel product.
+   * @param data The input to the cell.
+   * @param c Previous cell state.
+   * @param h Previous cell output.
+   * @return Tuple [nextCellState, cellOutput]
+   */
+  basicLSTMCell(forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D,
+      data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D] {
+    // Validate before opening a scope, mirroring multiRNNCell.
+    util.assert(
+        data.shape[0] === 1,
+        `Error in basicLSTMCell: first dimension of data is ` +
+            `${data.shape[0]}, but batch sizes > 1 are not yet supported.`);
+    const res = this.scope(() => {
+      // Equivalent to concat(data, h, axis=1): both arrays are reshaped to 3D
+      // so that concat3D can be used, then the result is reshaped back to 2D.
+      const data3D = data.as3D(1, 1, data.shape[1]);
+      const h3D = h.as3D(1, 1, h.shape[1]);
+      const combined3D = this.concat3D(data3D, h3D, 2);
+      const combined2D = combined3D.as2D(1, data.shape[1] + h.shape[1]);
+
+      const weighted = this.matMul(combined2D, lstmKernel);
+      const gates = this.add(weighted, lstmBias) as Array2D;
+
+      // The columns hold four equally sized slices, in this order:
+      // i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      const gateSize = gates.shape[1] / 4;
+      const sliceSize: [number, number] = [gates.shape[0], gateSize];
+      const i = this.slice2D(gates, [0, 0], sliceSize);
+      const j = this.slice2D(gates, [0, gateSize], sliceSize);
+      const f = this.slice2D(gates, [0, gateSize * 2], sliceSize);
+      const o = this.slice2D(gates, [0, gateSize * 3], sliceSize);
+
+      // newC = c * sigmoid(f + forgetBias) + sigmoid(i) * tanh(j)
+      const newC = this.add(
+          this.multiplyStrict(
+              c, this.sigmoid(this.scalarPlusArray(forgetBias, f))),
+          this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
+      // newH = tanh(newC) * sigmoid(o)
+      const newH = this.multiplyStrict(
+          this.tanh(newC), this.sigmoid(o)) as Array2D;
+
+      return [newC, newH];
+    });
+    return [res[0], res[1]];
+  }
+
}
export enum MatrixOrientation {
diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts
index 8b68af880b..79abe19349 100644
--- a/src/math/math_gpu_test.ts
+++ b/src/math/math_gpu_test.ts
@@ -2188,3 +2188,101 @@ describe('NDArrayMathGPU debug mode', () => {
expect(res.getValues()).toEqual(new Float32Array([2, NaN]));
});
});
+
+// Tests for NDArrayMath.basicLSTMCell and NDArrayMath.multiRNNCell on the GPU
+// backend.
+describe('LSTMCell', () => {
+  let math: NDArrayMathGPU;
+  beforeEach(() => {
+    math = new NDArrayMathGPU();
+    math.startScope();
+  });
+
+  afterEach(() => {
+    // Close the scope opened in beforeEach and dispose of this instance; a
+    // fresh NDArrayMathGPU is created for every test, so no new scope should
+    // be opened here.
+    math.endScope(null!);
+    math.dispose();
+  });
+
+  it('Batch size must be 1 for MultiRNNCell', () => {
+    const lstmKernel1 = Array2D.zeros([3, 4]);
+    const lstmBias1 = Array1D.zeros([4]);
+    const lstmKernel2 = Array2D.zeros([2, 4]);
+    const lstmBias2 = Array1D.zeros([4]);
+
+    const forgetBias = Scalar.new(1.0);
+    const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1,
+        lstmBias1);
+    const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2,
+        lstmBias2);
+
+    const c = [Array2D.zeros([1, lstmBias1.shape[0] / 4]),
+        Array2D.zeros([1, lstmBias2.shape[0] / 4])];
+    const h = [Array2D.zeros([1, lstmBias1.shape[0] / 4]),
+        Array2D.zeros([1, lstmBias2.shape[0] / 4])];
+
+    // Two rows, i.e. batch size 2, which is not supported.
+    const onehot = Array2D.zeros([2, 2]);
+    onehot.set(1.0, 1, 0);
+    const output = () => math.multiRNNCell([lstm1, lstm2], onehot, c, h);
+    expect(output).toThrowError();
+  });
+
+  it('Batch size must be 1 for basicLSTMCell', () => {
+    const lstmKernel = Array2D.zeros([3, 4]);
+    const lstmBias = Array1D.zeros([4]);
+
+    const forgetBias = Scalar.new(1.0);
+
+    const c = Array2D.zeros([1, lstmBias.shape[0] / 4]);
+    const h = Array2D.zeros([1, lstmBias.shape[0] / 4]);
+
+    // Two rows, i.e. batch size 2, which is not supported.
+    const onehot = Array2D.zeros([2, 2]);
+    onehot.set(1.0, 1, 0);
+    const output = () => math.basicLSTMCell(forgetBias, lstmKernel,
+        lstmBias, onehot, c, h);
+    expect(output).toThrowError();
+  });
+
+  it('MultiRNNCell with 2 BasicLSTMCells', () => {
+    const lstmKernel1 = Array2D.new([3, 4], new Float32Array([
+      0.26242125034332275, -0.8787832260131836, 0.781475305557251,
+      1.337337851524353, 0.6180247068405151, -0.2760246992111206,
+      -0.11299663782119751, -0.46332040429115295, -0.1765323281288147,
+      0.6807947158813477, -0.8326982855796814, 0.6732975244522095]));
+    const lstmBias1 = Array1D.new(new Float32Array([
+      1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]));
+    const lstmKernel2 = Array2D.new([2, 4], new Float32Array([
+      -1.893059492111206, -1.0185645818710327, -0.6270437240600586,
+      -2.1829540729522705, -0.4583775997161865, -0.5454602241516113,
+      -0.3114445209503174, 0.8450229167938232]));
+    const lstmBias2 = Array1D.new(new Float32Array([
+      0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]));
+
+    const forgetBias = Scalar.new(1.0);
+    const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1,
+        lstmBias1);
+    const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2,
+        lstmBias2);
+
+    const c = [Array2D.zeros([1, lstmBias1.shape[0] / 4]),
+        Array2D.zeros([1, lstmBias2.shape[0] / 4])];
+    const h = [Array2D.zeros([1, lstmBias1.shape[0] / 4]),
+        Array2D.zeros([1, lstmBias2.shape[0] / 4])];
+
+    const onehot = Array2D.zeros([1, 2]);
+    onehot.set(1.0, 0, 0);
+
+    const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h);
+
+    // output[0] holds the next cell states and output[1] the cell outputs,
+    // one entry per cell.
+    test_util.expectArraysClose(
+        output[0][0].getValues(), new Float32Array([-0.7440074682235718]),
+        1e-4);
+    test_util.expectArraysClose(
+        output[0][1].getValues(), new Float32Array([0.7460772395133972]),
+        1e-4);
+    test_util.expectArraysClose(
+        output[1][0].getValues(), new Float32Array([-0.5802832245826721]),
+        1e-4);
+    test_util.expectArraysClose(
+        output[1][1].getValues(), new Float32Array([0.5745711922645569]),
+        1e-4);
+  });
+});
+