Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, Ten
export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types';

export * from './ops/ops';
export {LSTMCellFunc} from './ops/lstm';
export {Reduction} from './ops/loss_ops';

export * from './train';
Expand Down
78 changes: 78 additions & 0 deletions tfjs-core/src/ops/basic_lstm_cell.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Scalar, Tensor1D, Tensor2D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

import {add} from './add';
import {concat} from './concat';
import {matMul} from './mat_mul';
import {mul} from './mul';
import {op} from './operation';
import {slice} from './slice';
import {sigmoid, tanh} from './unary_ops';

/**
* Computes the next state and output of a BasicLSTMCell.
*
* Returns `[newC, newH]`.
*
* Derived from tf.contrib.rnn.BasicLSTMCell.
*
* @param forgetBias Forget bias for the cell.
* @param lstmKernel The weights for the cell.
* @param lstmBias The bias for the cell.
* @param data The input to the cell.
* @param c Previous cell state.
* @param h Previous cell output.
*/
/** @doc {heading: 'Operations', subheading: 'RNN'} */
function basicLSTMCell_(
forgetBias: Scalar|TensorLike, lstmKernel: Tensor2D|TensorLike,
lstmBias: Tensor1D|TensorLike, data: Tensor2D|TensorLike,
c: Tensor2D|TensorLike, h: Tensor2D|TensorLike): [Tensor2D, Tensor2D] {
const $forgetBias =
convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
const $lstmKernel =
convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
const $data = convertToTensor(data, 'data', 'basicLSTMCell');
const $c = convertToTensor(c, 'c', 'basicLSTMCell');
const $h = convertToTensor(h, 'h', 'basicLSTMCell');

const combined = concat([$data, $h], 1);
const weighted = matMul(combined, $lstmKernel);
const res: Tensor2D = add(weighted, $lstmBias);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
const batchSize = res.shape[0];
const sliceCols = res.shape[1] / 4;
const sliceSize: [number, number] = [batchSize, sliceCols];
const i = slice(res, [0, 0], sliceSize);
const j = slice(res, [0, sliceCols], sliceSize);
const f = slice(res, [0, sliceCols * 2], sliceSize);
const o = slice(res, [0, sliceCols * 3], sliceSize);

const newC: Tensor2D =
add(mul(sigmoid(i), tanh(j)),
mul($c, sigmoid(add($forgetBias, f)) as Tensor2D));
const newH: Tensor2D = mul(tanh(newC), sigmoid(o));
return [newC, newH];
}

export const basicLSTMCell = op({basicLSTMCell_});
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -17,57 +17,9 @@

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {Tensor2D} from '../tensor';
import {expectArraysClose} from '../test_util';
import {Rank} from '../types';

describeWithFlags('lstm', ALL_ENVS, () => {
it('MultiRNNCell with 2 BasicLSTMCells', async () => {
const lstmKernel1 = tf.tensor2d(
[
0.26242125034332275, -0.8787832260131836, 0.781475305557251,
1.337337851524353, 0.6180247068405151, -0.2760246992111206,
-0.11299663782119751, -0.46332040429115295, -0.1765323281288147,
0.6807947158813477, -0.8326982855796814, 0.6732975244522095
],
[3, 4]);
const lstmBias1 = tf.tensor1d(
[1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]);
const lstmKernel2 = tf.tensor2d(
[
-1.893059492111206, -1.0185645818710327, -0.6270437240600586,
-2.1829540729522705, -0.4583775997161865, -0.5454602241516113,
-0.3114445209503174, 0.8450229167938232
],
[2, 4]);
const lstmBias2 = tf.tensor1d(
[0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]);

const forgetBias = tf.scalar(1.0);
const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);
const c = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];
const h = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];

const onehot = tf.buffer<Rank.R2>([1, 2], 'float32');
onehot.set(1.0, 0, 0);

const output = tf.multiRNNCell([lstm1, lstm2], onehot.toTensor(), c, h);

expectArraysClose(await output[0][0].data(), [-0.7440074682235718]);
expectArraysClose(await output[0][1].data(), [0.7460772395133972]);
expectArraysClose(await output[1][0].data(), [-0.5802832245826721]);
expectArraysClose(await output[1][1].data(), [0.5745711922645569]);
});

describeWithFlags('basicLSTMCell', ALL_ENVS, () => {
it('basicLSTMCell with batch=2', async () => {
const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);
const lstmBias = tf.randomNormal<Rank.R1>([4]);
Expand Down Expand Up @@ -106,79 +58,6 @@ describeWithFlags('lstm', ALL_ENVS, () => {
expect(newHVals[0][0]).toEqual(newHVals[1][0]);
});
});

describeWithFlags('multiRNN throws when passed non-tensor', ALL_ENVS, () => {
it('input: data', () => {
const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]);
const lstmBias1: tf.Tensor1D = tf.zeros([4]);
const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]);
const lstmBias2: tf.Tensor1D = tf.zeros([4]);

const forgetBias = tf.scalar(1.0);
const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);
const c = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];
const h = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];

expect(() => tf.multiRNNCell([lstm1, lstm2], {} as tf.Tensor2D, c, h))
.toThrowError(
/Argument 'data' passed to 'multiRNNCell' must be a Tensor/);
});

it('input: c', () => {
const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]);
const lstmBias1: tf.Tensor1D = tf.zeros([4]);
const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]);
const lstmBias2: tf.Tensor1D = tf.zeros([4]);

const forgetBias = tf.scalar(1.0);
const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);

const h = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];
const data: tf.Tensor2D = tf.zeros([1, 2]);

expect(() => tf.multiRNNCell([lstm1, lstm2], data, [{} as tf.Tensor2D], h))
.toThrowError(
/Argument 'c\[0\]' passed to 'multiRNNCell' must be a Tensor/);
});

it('input: h', () => {
const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]);
const lstmBias1: tf.Tensor1D = tf.zeros([4]);
const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]);
const lstmBias2: tf.Tensor1D = tf.zeros([4]);

const forgetBias = tf.scalar(1.0);
const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) =>
tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);
const c = [
tf.zeros<Rank.R2>([1, lstmBias1.shape[0] / 4]),
tf.zeros<Rank.R2>([1, lstmBias2.shape[0] / 4])
];
const data: tf.Tensor2D = tf.zeros([1, 2]);

expect(() => tf.multiRNNCell([lstm1, lstm2], data, c, [{} as tf.Tensor2D]))
.toThrowError(
/Argument 'h\[0\]' passed to 'multiRNNCell' must be a Tensor/);
});
});

describeWithFlags('basicLSTMCell throws with non-tensor', ALL_ENVS, () => {
it('input: forgetBias', () => {
const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);
Expand Down
54 changes: 2 additions & 52 deletions tfjs-core/src/ops/lstm.ts → tfjs-core/src/ops/multi_rnn_cell.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -14,8 +14,7 @@
* limitations under the License.
* =============================================================================
*/

import {Scalar, Tensor1D, Tensor2D} from '../tensor';
import {Tensor2D} from '../tensor';
import {convertToTensor, convertToTensorArray} from '../tensor_util_env';
import {TensorLike} from '../types';
import {op} from './operation';
Expand Down Expand Up @@ -66,53 +65,4 @@ function multiRNNCell_(
}
return [newC, newH];
}

/**
* Computes the next state and output of a BasicLSTMCell.
*
* Returns `[newC, newH]`.
*
* Derived from tf.contrib.rnn.BasicLSTMCell.
*
* @param forgetBias Forget bias for the cell.
* @param lstmKernel The weights for the cell.
* @param lstmBias The bias for the cell.
* @param data The input to the cell.
* @param c Previous cell state.
* @param h Previous cell output.
*/
/** @doc {heading: 'Operations', subheading: 'RNN'} */
function basicLSTMCell_(
forgetBias: Scalar|TensorLike, lstmKernel: Tensor2D|TensorLike,
lstmBias: Tensor1D|TensorLike, data: Tensor2D|TensorLike,
c: Tensor2D|TensorLike, h: Tensor2D|TensorLike): [Tensor2D, Tensor2D] {
const $forgetBias =
convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
const $lstmKernel =
convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
const $data = convertToTensor(data, 'data', 'basicLSTMCell');
const $c = convertToTensor(c, 'c', 'basicLSTMCell');
const $h = convertToTensor(h, 'h', 'basicLSTMCell');

const combined = $data.concat($h, 1);
const weighted = combined.matMul($lstmKernel);
const res: Tensor2D = weighted.add($lstmBias);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
const batchSize = res.shape[0];
const sliceCols = res.shape[1] / 4;
const sliceSize: [number, number] = [batchSize, sliceCols];
const i = res.slice([0, 0], sliceSize);
const j = res.slice([0, sliceCols], sliceSize);
const f = res.slice([0, sliceCols * 2], sliceSize);
const o = res.slice([0, sliceCols * 3], sliceSize);

const newC: Tensor2D = i.sigmoid().mul(j.tanh()).add(
$c.mul($forgetBias.add(f).sigmoid() as Tensor2D));
const newH: Tensor2D = newC.tanh().mul(o.sigmoid());
return [newC, newH];
}

export const basicLSTMCell = op({basicLSTMCell_});
export const multiRNNCell = op({multiRNNCell_});
Loading