diff --git a/demos/lstm/README.md b/demos/lstm/README.md new file mode 100644 index 0000000000..0bba1fc6ed --- /dev/null +++ b/demos/lstm/README.md @@ -0,0 +1,25 @@ +# Learning digits of pi using an LSTM + +Demonstrates training a simple autoregressive LSTM network in TensorFlow and +then porting that model to deeplearn.js. + +This network uses two ``BasicLSTMCell``s combined with ``MultiRNNCell``. The +network is trained to memorize the first few digits of pi. + +First, train the LSTM network with TensorFlow: + +``` +python demos/lstm/train.py +``` + +Next, export the weights to be used by deeplearn.js: + +``` +python scripts/dump_checkpoint_vars.py --output_dir=demos/lstm/ --checkpoint_file=/tmp/simple_lstm-1000 --remove_variables_regex=".*Adam.*|.*beta.*" +``` + +Finally, start the demo: + +``` +scripts/watch-demo demos/lstm/lstm_inference.ts +``` diff --git a/demos/lstm/fully_connected_biases b/demos/lstm/fully_connected_biases new file mode 100644 index 0000000000..9f8fb8cdf4 --- /dev/null +++ b/demos/lstm/fully_connected_biases @@ -0,0 +1 @@ +,:MP>ƳmZKfN<>ֻGd6H \ No newline at end of file diff --git a/demos/lstm/fully_connected_weights b/demos/lstm/fully_connected_weights new file mode 100644 index 0000000000..8061f0c24f Binary files /dev/null and b/demos/lstm/fully_connected_weights differ diff --git a/demos/lstm/index.html b/demos/lstm/index.html new file mode 100644 index 0000000000..82bb347816 --- /dev/null +++ b/demos/lstm/index.html @@ -0,0 +1,25 @@ + + + + LSTM Demo + + +

LSTM demo

+
Expected:
+
+
Results:
+
+
+ + + diff --git a/demos/lstm/lstm_inference.ts b/demos/lstm/lstm_inference.ts new file mode 100644 index 0000000000..7e38e127d0 --- /dev/null +++ b/demos/lstm/lstm_inference.ts @@ -0,0 +1,79 @@ +/* Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {Array1D, Array2D, CheckpointLoader, NDArrayMathGPU, Scalar, + util} from '../deeplearnjs'; + +// manifest.json lives in the same directory. 
+const reader = new CheckpointLoader('.'); +reader.getAllVariables().then(vars => { + const primerData = 3; + const expected = [1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4]; + const math = new NDArrayMathGPU(); + + const lstmKernel1 = vars[ + 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'] as Array2D; + const lstmBias1 = vars[ + 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'] as Array1D; + + const lstmKernel2 = vars[ + 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'] as Array2D; + const lstmBias2 = vars[ + 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'] as Array1D; + + const fullyConnectedBiases = vars['fully_connected/biases'] as Array1D; + const fullyConnectedWeights = vars['fully_connected/weights'] as Array2D; + + const results: number[] = []; + + math.scope((keep, track) => { + const forgetBias = track(Scalar.new(1.0)); + const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1, + lstmBias1); + const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2, + lstmBias2); + + let c = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])), + track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))]; + let h = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])), + track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))]; + + let input = primerData; + for (let i = 0; i < expected.length; i++) { + const onehot = track(Array2D.zeros([1, 10])); + onehot.set(1.0, 0, input); + + const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h); + + c = output[0]; + h = output[1]; + + const outputH = h[1]; + const weightedResult = math.matMul(outputH, fullyConnectedWeights); + const logits = math.add( weightedResult, fullyConnectedBiases); + + const result = math.argMax(logits).get(); + results.push(result); + input = result; + } + }); + document.getElementById('expected').innerHTML = '' + expected; + document.getElementById('results').innerHTML = '' + results; + if(util.arraysEqual(expected, results)) { + document.getElementById('success').innerHTML = 
'Success!'; + } else { + document.getElementById('success').innerHTML = 'Failure.'; + } +}); diff --git a/demos/lstm/manifest.json b/demos/lstm/manifest.json new file mode 100644 index 0000000000..13a1f25732 --- /dev/null +++ b/demos/lstm/manifest.json @@ -0,0 +1,41 @@ +{ + "fully_connected/biases": { + "filename": "fully_connected_biases", + "shape": [ + 10 + ] + }, + "fully_connected/weights": { + "filename": "fully_connected_weights", + "shape": [ + 20, + 10 + ] + }, + "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias": { + "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias", + "shape": [ + 80 + ] + }, + "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel": { + "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel", + "shape": [ + 30, + 80 + ] + }, + "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias": { + "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias", + "shape": [ + 80 + ] + }, + "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel": { + "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel", + "shape": [ + 40, + 80 + ] + } +} \ No newline at end of file diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias new file mode 100644 index 0000000000..be3bf0d5fe Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias differ diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel new file mode 100644 index 0000000000..756009d908 Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel differ diff --git a/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias new file mode 100644 index 0000000000..4ffe9fe5b7 Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias differ diff --git 
a/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel new file mode 100644 index 0000000000..58005430a6 Binary files /dev/null and b/demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel differ diff --git a/demos/lstm/train.py b/demos/lstm/train.py new file mode 100644 index 0000000000..bd89df8d8d --- /dev/null +++ b/demos/lstm/train.py @@ -0,0 +1,84 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Trains and Evaluates a simple LSTM network.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('output_dir', '/tmp/simple_lstm', + 'Directory to write checkpoint.') + +def main(unused_argv): + data = np.array( + [[3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4]]) + + tf.reset_default_graph() + + x = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1]) + y = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1]) + + NHIDDEN = 20 + NLABELS = 10 + + lstm1 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN) + lstm2 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN) + lstm = tf.contrib.rnn.MultiRNNCell([lstm1, lstm2]) + initial_state = lstm.zero_state(1, tf.float32) + + outputs, final_state = tf.nn.dynamic_rnn( + cell=lstm, inputs=tf.one_hot(x, NLABELS), initial_state=initial_state) + + logits = tf.contrib.layers.linear(outputs, NLABELS) + + softmax_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=y, logits=logits) + + predictions = tf.argmax(logits, axis=-1) + loss = tf.reduce_mean(softmax_cross_entropy) + + train_op = tf.train.AdamOptimizer().minimize(loss) + + sess = tf.InteractiveSession() + sess.run(tf.global_variables_initializer()) + + print('Starting training...') + NEPOCH = 1000 + for step in range(NEPOCH + 1): + loss_out, _ = sess.run([loss, train_op], + feed_dict={ + x: data[:, :-1], + y: data[:, 1:], + }) + if step % 100 == 0: + print('Loss at step {}: {}'.format(step, loss_out)) + + print('Expected data:') + print(data[:, 1:]) + print('Results:') + print(sess.run([predictions], feed_dict={x: data[:, :-1]})) + + saver = tf.train.Saver() + path = saver.save(sess, FLAGS.output_dir, global_step=step) + print('Saved checkpoint at {}'.format(path)) + + +if __name__ == '__main__': + 
+ tf.app.run(main) + diff --git a/src/math/math.ts b/src/math/math.ts index 8d693a570e..13a02935e7 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -21,6 +21,11 @@ import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; export type ScopeResult = NDArray[]|NDArray|void; +export interface LSTMCell { + (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D]; +} + + export abstract class NDArrayMath { private ndarrayScopes: NDArray[][] = []; private activeScope: NDArray[]; @@ -1127,6 +1132,100 @@ export abstract class NDArrayMath { x: Array3D, mean: Array3D|Array1D, variance: Array3D|Array1D, varianceEpsilon: number, scale?: Array3D|Array1D, offset?: Array3D|Array1D): Array3D; + + ////////////// + // LSTM ops // + ////////////// + + /** + * Computes the next states and outputs of a stack of LSTMCells. + * Each cell output is used as input to the next cell. + * This is only the forward mode. + * Derived from tf.contrib.rnn.MultiRNNCell. + * @param lstmCells Array of LSTMCell functions. + * @param data The input to the cell. + * @param c Array of previous cell states. + * @param h Array of previous cell outputs. 
+ * @return Tuple [nextCellStates, cellOutputs] + */ + multiRNNCell(lstmCells: LSTMCell[], data: Array2D, c: Array2D[], + h: Array2D[]): [Array2D[], Array2D[]] { + util.assert( + data.shape[0] === 1, + `Error in multiRNNCell: first dimension of data is ${data.shape[0]}, ` + + `but batch sizes > 1 are not yet supported.`); + const res = this.scope(() => { + let input = data; + const newStates = []; + for (let i = 0; i < lstmCells.length; i++) { + const output = lstmCells[i](input, c[i], h[i]); + newStates.push(output[0]); + newStates.push(output[1]); + input = output[1]; + } + + return newStates; + }); + const newC: Array2D[] = []; + const newH: Array2D[] = []; + for (let i = 0; i < res.length; i += 2) { + newC.push(res[i] as Array2D); + newH.push(res[i + 1] as Array2D); + } + return [newC, newH]; + } + + /** + * Computes the next state and output of a BasicLSTMCell. + * This is only the forward mode. + * Derived from tf.contrib.rnn.BasicLSTMCell. + * @param forgetBias Forget bias for the cell. + * @param lstmKernel The weights for the cell. + * @param lstmBias The biases for the cell. + * @param data The input to the cell. + * @param c Previous cell state. + * @param h Previous cell output. + * @return Tuple [nextCellState, cellOutput] + */ + basicLSTMCell(forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D, + data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D] { + const res = this.scope(() => { + util.assert( + data.shape[0] === 1, + `Error in basicLSTMCell: first dimension of data is ` + + `${data.shape[0]}, but batch sizes > 1 are not yet supported.`); + // concat(inputs, h, 1) + // There is no concat1d, so reshape inputs and h to 3d, concat, then + // reshape back to 2d. 
+ const data3D = data.as3D(1, 1, data.shape[1]); + const h3D = h.as3D(1, 1, h.shape[1]); + const combined3D = this.concat3D(data3D, h3D, 2); + const combined2D = combined3D.as2D(1, data.shape[1] + h.shape[1]); + + const weighted = this.matMul(combined2D, lstmKernel); + const res = this.add(weighted, lstmBias) as Array2D; + + // i = input_gate, j = new_input, f = forget_gate, o = output_gate + const i = this.slice2D(res, [0, 0], [res.shape[0], res.shape[1] / 4]); + const j = this.slice2D(res, [0, res.shape[1] / 4 * 1], + [res.shape[0], res.shape[1] / 4]); + const f = this.slice2D(res, [0, res.shape[1] / 4 * 2], + [res.shape[0], res.shape[1] / 4]); + const o = this.slice2D(res, [0, res.shape[1] / 4 * 3], + [res.shape[0], res.shape[1] / 4]); + + const newC = this.add( + this.multiplyStrict(c, + this.sigmoid(this.scalarPlusArray(forgetBias, f))), + this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D; + const newH = this.multiplyStrict( + this.tanh(newC), this.sigmoid(o)) as Array2D; + + return [newC, newH]; + }); + return [res[0], res[1]]; + } + } export enum MatrixOrientation { diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index 8b68af880b..79abe19349 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -2188,3 +2188,101 @@ describe('NDArrayMathGPU debug mode', () => { expect(res.getValues()).toEqual(new Float32Array([2, NaN])); }); }); + +describe('LSTMCell', () => { + let math: NDArrayMathGPU; + beforeEach(() => { + math = new NDArrayMathGPU(); + math.startScope(); + }); + + afterEach(() => { + math.endScope(null!); + math.startScope(); + }); + + it('Batch size must be 1 for MultiRNNCell', () => { + const lstmKernel1 = Array2D.zeros([3, 4]); + const lstmBias1 = Array1D.zeros([4]); + const lstmKernel2 = Array2D.zeros([2, 4]); + const lstmBias2 = Array1D.zeros([4]); + + const forgetBias = Scalar.new(1.0); + const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1, + lstmBias1); + const lstm2 = 
math.basicLSTMCell.bind(math, forgetBias, lstmKernel2, + lstmBias2); + + const c = [Array2D.zeros([1, lstmBias1.shape[0] / 4]), + Array2D.zeros([1, lstmBias2.shape[0] / 4])]; + const h = [Array2D.zeros([1, lstmBias1.shape[0] / 4]), + Array2D.zeros([1, lstmBias2.shape[0] / 4])]; + + const onehot = Array2D.zeros([2, 2]); + onehot.set(1.0, 1, 0); + const output = () => math.multiRNNCell([lstm1, lstm2], onehot, c, h); + expect(output).toThrowError(); + }); + + it('Batch size must be 1 for basicLSTMCell', () => { + const lstmKernel = Array2D.zeros([3, 4]); + const lstmBias = Array1D.zeros([4]); + + const forgetBias = Scalar.new(1.0); + + const c = Array2D.zeros([1, lstmBias.shape[0] / 4]); + const h = Array2D.zeros([1, lstmBias.shape[0] / 4]); + + const onehot = Array2D.zeros([2, 2]); + onehot.set(1.0, 1, 0); + const output = () => math.basicLSTMCell(forgetBias, lstmKernel, + lstmBias, onehot, c, h); + expect(output).toThrowError(); + }); + + it('MultiRNNCell with 2 BasicLSTMCells', () => { + const lstmKernel1 = Array2D.new([3, 4], new Float32Array([ + 0.26242125034332275, -0.8787832260131836, 0.781475305557251, + 1.337337851524353, 0.6180247068405151, -0.2760246992111206, + -0.11299663782119751, -0.46332040429115295, -0.1765323281288147, + 0.6807947158813477, -0.8326982855796814, 0.6732975244522095])); + const lstmBias1 = Array1D.new(new Float32Array([ + 1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917])); + const lstmKernel2 = Array2D.new([2, 4], new Float32Array([ + -1.893059492111206, -1.0185645818710327, -0.6270437240600586, + -2.1829540729522705, -0.4583775997161865, -0.5454602241516113, + -0.3114445209503174, 0.8450229167938232])); + const lstmBias2 = Array1D.new(new Float32Array([ + 0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806])); + + const forgetBias = Scalar.new(1.0); + const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1, + lstmBias1); + const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2, + 
lstmBias2); + + const c = [Array2D.zeros([1, lstmBias1.shape[0] / 4]), + Array2D.zeros([1, lstmBias2.shape[0] / 4])]; + const h = [Array2D.zeros([1, lstmBias1.shape[0] / 4]), + Array2D.zeros([1, lstmBias2.shape[0] / 4])]; + + const onehot = Array2D.zeros([1, 2]); + onehot.set(1.0, 0, 0); + + const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h); + + test_util.expectArraysClose( + output[0][0].getValues(), new Float32Array([-0.7440074682235718]), + 1e-4); + test_util.expectArraysClose( + output[0][1].getValues(), new Float32Array([0.7460772395133972]), + 1e-4); + test_util.expectArraysClose( + output[1][0].getValues(), new Float32Array([-0.5802832245826721]), + 1e-4); + test_util.expectArraysClose( + output[1][1].getValues(), new Float32Array([0.5745711922645569]), + 1e-4); + }); +}); +