25 changes: 25 additions & 0 deletions demos/lstm/README.md
@@ -0,0 +1,25 @@
# Learning digits of pi using an LSTM

Demonstrates training a simple autoregressive LSTM network in TensorFlow and
then porting that model to deeplearn.js.

This network uses two ``BasicLSTMCell``s combined with a ``MultiRNNCell``. The
network is trained to memorize the first few digits of pi.

First, train the LSTM network with TensorFlow:

```
python demos/lstm/train.py
```

Next, export the weights to be used by deeplearn.js. This writes
``manifest.json`` and one file per variable into ``demos/lstm/``:

```
python scripts/dump_checkpoint_vars.py --output_dir=demos/lstm/ --checkpoint_file=/tmp/simple_lstm-1000 --remove_variables_regex=".*Adam.*|.*beta.*"
```

Finally, start the demo:

```
scripts/watch-demo demos/lstm/lstm.ts
```
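
For reference, here is a minimal sketch of how the demo consumes the exported
checkpoint. It mirrors ``lstm_inference.ts`` in this change; only one of the
two cells is shown, and the variable names are illustrative:

```
import {Array1D, Array2D, CheckpointLoader, NDArrayMathGPU,
        Scalar} from '../deeplearnjs';

// manifest.json and the dumped variables live in this directory.
const reader = new CheckpointLoader('.');
reader.getAllVariables().then(vars => {
  const math = new NDArrayMathGPU();
  const kernel =
      vars['rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'] as Array2D;
  const bias =
      vars['rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'] as Array1D;

  math.scope((keep, track) => {
    const forgetBias = track(Scalar.new(1.0));
    // Bind the loaded weights once; the resulting function maps
    // (input, c, h) to the next cell state and output.
    const lstm = math.basicLSTMCell.bind(math, forgetBias, kernel, bias);
    // See lstm_inference.ts for the full autoregressive loop.
  });
});
```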
1 change: 1 addition & 0 deletions demos/lstm/fully_connected_biases
Binary file not shown.
Binary file added demos/lstm/fully_connected_weights
Binary file not shown.
25 changes: 25 additions & 0 deletions demos/lstm/index.html
@@ -0,0 +1,25 @@
<!-- Copyright 2017 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================-->
<html>
  <head>
    <title>LSTM Demo</title>
  </head>
  <body>
    <h1>LSTM demo</h1>
    <div>Expected:</div>
    <div id="expected"></div>
    <div>Results:</div>
    <div id="results"></div>
    <div id="success"></div>
    <script src="bundle.js"></script>
  </body>
</html>
79 changes: 79 additions & 0 deletions demos/lstm/lstm_inference.ts
@@ -0,0 +1,79 @@
/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

import {Array1D, Array2D, CheckpointLoader, NDArrayMathGPU, Scalar,
        util} from '../deeplearnjs';

// manifest.json lives in the same directory.
const reader = new CheckpointLoader('.');
reader.getAllVariables().then(vars => {
  const primerData = 3;
  const expected = [1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4];
  const math = new NDArrayMathGPU();

  const lstmKernel1 = vars[
      'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'] as Array2D;
  const lstmBias1 = vars[
      'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'] as Array1D;

  const lstmKernel2 = vars[
      'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'] as Array2D;
  const lstmBias2 = vars[
      'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'] as Array1D;

  const fullyConnectedBiases = vars['fully_connected/biases'] as Array1D;
  const fullyConnectedWeights = vars['fully_connected/weights'] as Array2D;

  const results: number[] = [];

  math.scope((keep, track) => {
    const forgetBias = track(Scalar.new(1.0));
    const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1,
                                          lstmBias1);
    const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2,
                                          lstmBias2);

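    // Each LSTM bias vector packs the four gate biases (i, j, f, o), so the
    // per-cell hidden size is one quarter of the bias length.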
    let c = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
             track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];
    let h = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
             track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];

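    // Generate digits autoregressively: start from the primer digit and feed
    // each predicted digit back in as the next input.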
    let input = primerData;
    for (let i = 0; i < expected.length; i++) {
      const onehot = track(Array2D.zeros([1, 10]));
      onehot.set(1.0, 0, input);

      const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h);

      c = output[0];
      h = output[1];

      const outputH = h[1];
      const weightedResult = math.matMul(outputH, fullyConnectedWeights);
      const logits = math.add(weightedResult, fullyConnectedBiases);

      const result = math.argMax(logits).get();
      results.push(result);
      input = result;
    }
  });
  document.getElementById('expected').innerHTML = '' + expected;
  document.getElementById('results').innerHTML = '' + results;
  if (util.arraysEqual(expected, results)) {
    document.getElementById('success').innerHTML = 'Success!';
  } else {
    document.getElementById('success').innerHTML = 'Failure.';
  }
});
41 changes: 41 additions & 0 deletions demos/lstm/manifest.json
@@ -0,0 +1,41 @@
{
  "fully_connected/biases": {
    "filename": "fully_connected_biases",
    "shape": [
      10
    ]
  },
  "fully_connected/weights": {
    "filename": "fully_connected_weights",
    "shape": [
      20,
      10
    ]
  },
  "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias": {
    "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias",
    "shape": [
      80
    ]
  },
  "rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel": {
    "filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel",
    "shape": [
      30,
      80
    ]
  },
  "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias": {
    "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias",
    "shape": [
      80
    ]
  },
  "rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel": {
    "filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel",
    "shape": [
      40,
      80
    ]
  }
}
Binary file added demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias
Binary file not shown.
Binary file added demos/lstm/rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel
Binary file not shown.
Binary file added demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias
Binary file not shown.
Binary file added demos/lstm/rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel
Binary file not shown.
84 changes: 84 additions & 0 deletions demos/lstm/train.py
@@ -0,0 +1,84 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains and evaluates a simple LSTM network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('output_dir', '/tmp/simple_lstm',
                           'Directory to write checkpoint.')

def main(unused_argv):
  data = np.array(
      [[3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4]])

  tf.reset_default_graph()

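  # Next-digit prediction: the inputs are digits[:-1] and the targets are the
  # same sequence shifted left by one (digits[1:]).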
  x = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1])
  y = tf.placeholder(dtype=tf.int32, shape=[1, data.shape[1] - 1])

  NHIDDEN = 20
  NLABELS = 10

  lstm1 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN)
  lstm2 = tf.contrib.rnn.BasicLSTMCell(NHIDDEN)
  lstm = tf.contrib.rnn.MultiRNNCell([lstm1, lstm2])
  initial_state = lstm.zero_state(1, tf.float32)

  outputs, final_state = tf.nn.dynamic_rnn(
      cell=lstm, inputs=tf.one_hot(x, NLABELS), initial_state=initial_state)

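  # Project each LSTM output onto the 10 digit classes.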
  logits = tf.contrib.layers.linear(outputs, NLABELS)

  softmax_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=y, logits=logits)

  predictions = tf.argmax(logits, axis=-1)
  loss = tf.reduce_mean(softmax_cross_entropy)

  train_op = tf.train.AdamOptimizer().minimize(loss)

  sess = tf.InteractiveSession()
  sess.run(tf.global_variables_initializer())

  print('Starting training...')
  NEPOCH = 1000
  for step in range(NEPOCH + 1):
    loss_out, _ = sess.run([loss, train_op],
                           feed_dict={
                               x: data[:, :-1],
                               y: data[:, 1:],
                           })
    if step % 100 == 0:
      print('Loss at step {}: {}'.format(step, loss_out))

  print('Expected data:')
  print(data[:, 1:])
  print('Results:')
  print(sess.run([predictions], feed_dict={x: data[:, :-1]}))

  saver = tf.train.Saver()
  path = saver.save(sess, FLAGS.output_dir, global_step=step)
  print('Saved checkpoint at {}'.format(path))


if __name__ == '__main__':
  tf.app.run(main)

99 changes: 99 additions & 0 deletions src/math/math.ts
@@ -21,6 +21,11 @@ import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray';

export type ScopeResult = NDArray[]|NDArray|void;

export interface LSTMCell {
  (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D];
}


export abstract class NDArrayMath {
  private ndarrayScopes: NDArray[][] = [];
  private activeScope: NDArray[];
@@ -1127,6 +1132,100 @@
      x: Array3D, mean: Array3D|Array1D, variance: Array3D|Array1D,
      varianceEpsilon: number, scale?: Array3D|Array1D,
      offset?: Array3D|Array1D): Array3D;

  //////////////
  // LSTM ops //
  //////////////

  /**
   * Computes the next states and outputs of a stack of LSTMCells.
   * Each cell output is used as input to the next cell.
   * This is only the forward mode.
   * Derived from tf.contrib.rnn.MultiRNNCell.
   * @param lstmCells Array of LSTMCell functions.
   * @param data The input to the cell.
   * @param c Array of previous cell states.
   * @param h Array of previous cell outputs.
   * @return Tuple [nextCellStates, cellOutputs]
   */
  multiRNNCell(lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
               h: Array2D[]): [Array2D[], Array2D[]] {
    util.assert(
        data.shape[0] === 1,
        `Error in multiRNNCell: first dimension of data is ${data.shape[0]}, ` +
            `but batch sizes > 1 are not yet supported.`);
    const res = this.scope(() => {
      let input = data;
      const newStates = [];
      for (let i = 0; i < lstmCells.length; i++) {
        const output = lstmCells[i](input, c[i], h[i]);
        newStates.push(output[0]);
        newStates.push(output[1]);
        input = output[1];
      }

      return newStates;
    });
    const newC: Array2D[] = [];
    const newH: Array2D[] = [];
    for (let i = 0; i < res.length; i += 2) {
      newC.push(res[i] as Array2D);
      newH.push(res[i + 1] as Array2D);
    }
    return [newC, newH];
  }

  /**
   * Computes the next state and output of a BasicLSTMCell.
   * This is only the forward mode.
   * Derived from tf.contrib.rnn.BasicLSTMCell.
   * @param forgetBias Forget bias for the cell.
   * @param lstmKernel The weights for the cell.
   * @param lstmBias The biases for the cell.
   * @param data The input to the cell.
   * @param c Previous cell state.
   * @param h Previous cell output.
   * @return Tuple [nextCellState, cellOutput]
   */
  basicLSTMCell(forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D,
                data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D] {
    const res = this.scope(() => {
      util.assert(
          data.shape[0] === 1,
          `Error in basicLSTMCell: first dimension of data is ` +
              `${data.shape[0]}, but batch sizes > 1 are not yet supported.`);
      // concat(inputs, h, 1)
      // There is no concat1d, so reshape inputs and h to 3d, concat, then
      // reshape back to 2d.
      const data3D = data.as3D(1, 1, data.shape[1]);
      const h3D = h.as3D(1, 1, h.shape[1]);
      const combined3D = this.concat3D(data3D, h3D, 2);
      const combined2D = combined3D.as2D(1, data.shape[1] + h.shape[1]);

      const weighted = this.matMul(combined2D, lstmKernel);
      const res = this.add(weighted, lstmBias) as Array2D;

      // i = input_gate, j = new_input, f = forget_gate, o = output_gate
      const i = this.slice2D(res, [0, 0], [res.shape[0], res.shape[1] / 4]);
      const j = this.slice2D(res, [0, res.shape[1] / 4 * 1],
                             [res.shape[0], res.shape[1] / 4]);
      const f = this.slice2D(res, [0, res.shape[1] / 4 * 2],
                             [res.shape[0], res.shape[1] / 4]);
      const o = this.slice2D(res, [0, res.shape[1] / 4 * 3],
                             [res.shape[0], res.shape[1] / 4]);

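      // newC = sigmoid(f + forgetBias) * c + sigmoid(i) * tanh(j)
      // newH = sigmoid(o) * tanh(newC)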
      const newC = this.add(
          this.multiplyStrict(c,
              this.sigmoid(this.scalarPlusArray(forgetBias, f))),
          this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
      const newH = this.multiplyStrict(
          this.tanh(newC), this.sigmoid(o)) as Array2D;

      return [newC, newH];
    });
    return [res[0], res[1]];
  }

}

export enum MatrixOrientation {