From ae77da2ac92c80c65c74e8566c42794f2ef8244c Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 23 Jun 2020 11:48:06 -0400 Subject: [PATCH] modularize lstm ops --- tfjs-core/src/index.ts | 1 - tfjs-core/src/ops/basic_lstm_cell.ts | 78 ++++++++++ .../{lstm_test.ts => basic_lstm_cell_test.ts} | 125 +-------------- .../src/ops/{lstm.ts => multi_rnn_cell.ts} | 54 +------ tfjs-core/src/ops/multi_rnn_cell_test.ts | 142 ++++++++++++++++++ tfjs-core/src/ops/ops.ts | 3 +- tfjs-core/src/tests.ts | 3 +- 7 files changed, 228 insertions(+), 178 deletions(-) create mode 100644 tfjs-core/src/ops/basic_lstm_cell.ts rename tfjs-core/src/ops/{lstm_test.ts => basic_lstm_cell_test.ts} (56%) rename tfjs-core/src/ops/{lstm.ts => multi_rnn_cell.ts} (53%) create mode 100644 tfjs-core/src/ops/multi_rnn_cell_test.ts diff --git a/tfjs-core/src/index.ts b/tfjs-core/src/index.ts index 72432d68fe5..27d710984bc 100644 --- a/tfjs-core/src/index.ts +++ b/tfjs-core/src/index.ts @@ -64,7 +64,6 @@ export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, Ten export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types'; export * from './ops/ops'; -export {LSTMCellFunc} from './ops/lstm'; export {Reduction} from './ops/loss_ops'; export * from './train'; diff --git a/tfjs-core/src/ops/basic_lstm_cell.ts b/tfjs-core/src/ops/basic_lstm_cell.ts new file mode 100644 index 00000000000..fba07bc50e4 --- /dev/null +++ b/tfjs-core/src/ops/basic_lstm_cell.ts @@ -0,0 +1,78 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Scalar, Tensor1D, Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; + +import {add} from './add'; +import {concat} from './concat'; +import {matMul} from './mat_mul'; +import {mul} from './mul'; +import {op} from './operation'; +import {slice} from './slice'; +import {sigmoid, tanh} from './unary_ops'; + +/** + * Computes the next state and output of a BasicLSTMCell. + * + * Returns `[newC, newH]`. + * + * Derived from tf.contrib.rnn.BasicLSTMCell. + * + * @param forgetBias Forget bias for the cell. + * @param lstmKernel The weights for the cell. + * @param lstmBias The bias for the cell. + * @param data The input to the cell. + * @param c Previous cell state. + * @param h Previous cell output. 
+ */ +/** @doc {heading: 'Operations', subheading: 'RNN'} */ +function basicLSTMCell_( + forgetBias: Scalar|TensorLike, lstmKernel: Tensor2D|TensorLike, + lstmBias: Tensor1D|TensorLike, data: Tensor2D|TensorLike, + c: Tensor2D|TensorLike, h: Tensor2D|TensorLike): [Tensor2D, Tensor2D] { + const $forgetBias = + convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell'); + const $lstmKernel = + convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell'); + const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell'); + const $data = convertToTensor(data, 'data', 'basicLSTMCell'); + const $c = convertToTensor(c, 'c', 'basicLSTMCell'); + const $h = convertToTensor(h, 'h', 'basicLSTMCell'); + + const combined = concat([$data, $h], 1); + const weighted = matMul(combined, $lstmKernel); + const res: Tensor2D = add(weighted, $lstmBias); + + // i = input_gate, j = new_input, f = forget_gate, o = output_gate + const batchSize = res.shape[0]; + const sliceCols = res.shape[1] / 4; + const sliceSize: [number, number] = [batchSize, sliceCols]; + const i = slice(res, [0, 0], sliceSize); + const j = slice(res, [0, sliceCols], sliceSize); + const f = slice(res, [0, sliceCols * 2], sliceSize); + const o = slice(res, [0, sliceCols * 3], sliceSize); + + const newC: Tensor2D = + add(mul(sigmoid(i), tanh(j)), + mul($c, sigmoid(add($forgetBias, f)) as Tensor2D)); + const newH: Tensor2D = mul(tanh(newC), sigmoid(o)); + return [newC, newH]; +} + +export const basicLSTMCell = op({basicLSTMCell_}); diff --git a/tfjs-core/src/ops/lstm_test.ts b/tfjs-core/src/ops/basic_lstm_cell_test.ts similarity index 56% rename from tfjs-core/src/ops/lstm_test.ts rename to tfjs-core/src/ops/basic_lstm_cell_test.ts index aa014bd6dfa..0d19d515c36 100644 --- a/tfjs-core/src/ops/lstm_test.ts +++ b/tfjs-core/src/ops/basic_lstm_cell_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at @@ -17,57 +17,9 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; -import {Tensor2D} from '../tensor'; -import {expectArraysClose} from '../test_util'; import {Rank} from '../types'; -describeWithFlags('lstm', ALL_ENVS, () => { - it('MultiRNNCell with 2 BasicLSTMCells', async () => { - const lstmKernel1 = tf.tensor2d( - [ - 0.26242125034332275, -0.8787832260131836, 0.781475305557251, - 1.337337851524353, 0.6180247068405151, -0.2760246992111206, - -0.11299663782119751, -0.46332040429115295, -0.1765323281288147, - 0.6807947158813477, -0.8326982855796814, 0.6732975244522095 - ], - [3, 4]); - const lstmBias1 = tf.tensor1d( - [1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]); - const lstmKernel2 = tf.tensor2d( - [ - -1.893059492111206, -1.0185645818710327, -0.6270437240600586, - -2.1829540729522705, -0.4583775997161865, -0.5454602241516113, - -0.3114445209503174, 0.8450229167938232 - ], - [2, 4]); - const lstmBias2 = tf.tensor1d( - [0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]); - - const forgetBias = tf.scalar(1.0); - const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); - const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); - const c = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - const h = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - - const onehot = tf.buffer([1, 2], 'float32'); - onehot.set(1.0, 0, 0); - - const output = tf.multiRNNCell([lstm1, lstm2], onehot.toTensor(), c, h); - - expectArraysClose(await output[0][0].data(), [-0.7440074682235718]); - expectArraysClose(await output[0][1].data(), [0.7460772395133972]); - expectArraysClose(await output[1][0].data(), [-0.5802832245826721]); - expectArraysClose(await output[1][1].data(), [0.5745711922645569]); - }); - +describeWithFlags('basicLSTMCell', ALL_ENVS, () => { it('basicLSTMCell with batch=2', async () => { const lstmKernel = tf.randomNormal([3, 4]); const lstmBias = tf.randomNormal([4]); @@ -106,79 +58,6 @@ describeWithFlags('lstm', ALL_ENVS, () => { expect(newHVals[0][0]).toEqual(newHVals[1][0]); }); }); - -describeWithFlags('multiRNN throws when passed non-tensor', ALL_ENVS, () => { - it('input: data', () => { - const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); - const lstmBias1: tf.Tensor1D = tf.zeros([4]); - const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); - const lstmBias2: tf.Tensor1D = tf.zeros([4]); - - const forgetBias = tf.scalar(1.0); - const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); - const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); - const c = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - const h = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - - expect(() => tf.multiRNNCell([lstm1, lstm2], {} as tf.Tensor2D, c, h)) - .toThrowError( - /Argument 'data' passed to 'multiRNNCell' must be a Tensor/); - }); - - it('input: c', () => { - const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); - const lstmBias1: tf.Tensor1D = tf.zeros([4]); - const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); - const lstmBias2: tf.Tensor1D = 
tf.zeros([4]); - - const forgetBias = tf.scalar(1.0); - const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); - const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); - - const h = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - const data: tf.Tensor2D = tf.zeros([1, 2]); - - expect(() => tf.multiRNNCell([lstm1, lstm2], data, [{} as tf.Tensor2D], h)) - .toThrowError( - /Argument 'c\[0\]' passed to 'multiRNNCell' must be a Tensor/); - }); - - it('input: h', () => { - const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); - const lstmBias1: tf.Tensor1D = tf.zeros([4]); - const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); - const lstmBias2: tf.Tensor1D = tf.zeros([4]); - - const forgetBias = tf.scalar(1.0); - const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); - const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => - tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); - const c = [ - tf.zeros([1, lstmBias1.shape[0] / 4]), - tf.zeros([1, lstmBias2.shape[0] / 4]) - ]; - const data: tf.Tensor2D = tf.zeros([1, 2]); - - expect(() => tf.multiRNNCell([lstm1, lstm2], data, c, [{} as tf.Tensor2D])) - .toThrowError( - /Argument 'h\[0\]' passed to 'multiRNNCell' must be a Tensor/); - }); -}); - describeWithFlags('basicLSTMCell throws with non-tensor', ALL_ENVS, () => { it('input: forgetBias', () => { const lstmKernel = tf.randomNormal([3, 4]); diff --git a/tfjs-core/src/ops/lstm.ts b/tfjs-core/src/ops/multi_rnn_cell.ts similarity index 53% rename from tfjs-core/src/ops/lstm.ts rename to tfjs-core/src/ops/multi_rnn_cell.ts index 7f43c500100..a57f7bab7d6 100644 --- a/tfjs-core/src/ops/lstm.ts +++ b/tfjs-core/src/ops/multi_rnn_cell.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -14,8 +14,7 @@ * limitations under the License. * ============================================================================= */ - -import {Scalar, Tensor1D, Tensor2D} from '../tensor'; +import {Tensor2D} from '../tensor'; import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; import {TensorLike} from '../types'; import {op} from './operation'; @@ -66,53 +65,4 @@ function multiRNNCell_( } return [newC, newH]; } - -/** - * Computes the next state and output of a BasicLSTMCell. - * - * Returns `[newC, newH]`. - * - * Derived from tf.contrib.rnn.BasicLSTMCell. - * - * @param forgetBias Forget bias for the cell. - * @param lstmKernel The weights for the cell. - * @param lstmBias The bias for the cell. - * @param data The input to the cell. - * @param c Previous cell state. - * @param h Previous cell output. 
- */ -/** @doc {heading: 'Operations', subheading: 'RNN'} */ -function basicLSTMCell_( - forgetBias: Scalar|TensorLike, lstmKernel: Tensor2D|TensorLike, - lstmBias: Tensor1D|TensorLike, data: Tensor2D|TensorLike, - c: Tensor2D|TensorLike, h: Tensor2D|TensorLike): [Tensor2D, Tensor2D] { - const $forgetBias = - convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell'); - const $lstmKernel = - convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell'); - const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell'); - const $data = convertToTensor(data, 'data', 'basicLSTMCell'); - const $c = convertToTensor(c, 'c', 'basicLSTMCell'); - const $h = convertToTensor(h, 'h', 'basicLSTMCell'); - - const combined = $data.concat($h, 1); - const weighted = combined.matMul($lstmKernel); - const res: Tensor2D = weighted.add($lstmBias); - - // i = input_gate, j = new_input, f = forget_gate, o = output_gate - const batchSize = res.shape[0]; - const sliceCols = res.shape[1] / 4; - const sliceSize: [number, number] = [batchSize, sliceCols]; - const i = res.slice([0, 0], sliceSize); - const j = res.slice([0, sliceCols], sliceSize); - const f = res.slice([0, sliceCols * 2], sliceSize); - const o = res.slice([0, sliceCols * 3], sliceSize); - - const newC: Tensor2D = i.sigmoid().mul(j.tanh()).add( - $c.mul($forgetBias.add(f).sigmoid() as Tensor2D)); - const newH: Tensor2D = newC.tanh().mul(o.sigmoid()); - return [newC, newH]; -} - -export const basicLSTMCell = op({basicLSTMCell_}); export const multiRNNCell = op({multiRNNCell_}); diff --git a/tfjs-core/src/ops/multi_rnn_cell_test.ts b/tfjs-core/src/ops/multi_rnn_cell_test.ts new file mode 100644 index 00000000000..689519108e5 --- /dev/null +++ b/tfjs-core/src/ops/multi_rnn_cell_test.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {Tensor2D} from '../tensor'; +import {expectArraysClose} from '../test_util'; +import {Rank} from '../types'; + +describeWithFlags('lstm', ALL_ENVS, () => { + it('MultiRNNCell with 2 BasicLSTMCells', async () => { + const lstmKernel1 = tf.tensor2d( + [ + 0.26242125034332275, -0.8787832260131836, 0.781475305557251, + 1.337337851524353, 0.6180247068405151, -0.2760246992111206, + -0.11299663782119751, -0.46332040429115295, -0.1765323281288147, + 0.6807947158813477, -0.8326982855796814, 0.6732975244522095 + ], + [3, 4]); + const lstmBias1 = tf.tensor1d( + [1.090713620185852, -0.8282332420349121, 0, 1.0889357328414917]); + const lstmKernel2 = tf.tensor2d( + [ + -1.893059492111206, -1.0185645818710327, -0.6270437240600586, + -2.1829540729522705, -0.4583775997161865, -0.5454602241516113, + -0.3114445209503174, 0.8450229167938232 + ], + [2, 4]); + const lstmBias2 = tf.tensor1d( + [0.9906240105628967, 0.6248329877853394, 0, 1.0224634408950806]); + + const forgetBias = tf.scalar(1.0); + const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); + const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); + const c = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + const h = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + + const onehot = tf.buffer([1, 2], 'float32'); + onehot.set(1.0, 0, 0); + + const output = tf.multiRNNCell([lstm1, lstm2], onehot.toTensor(), c, h); + + expectArraysClose(await output[0][0].data(), [-0.7440074682235718]); + expectArraysClose(await output[0][1].data(), [0.7460772395133972]); + expectArraysClose(await output[1][0].data(), [-0.5802832245826721]); + expectArraysClose(await output[1][1].data(), [0.5745711922645569]); + }); +}); + +describeWithFlags('multiRNN throws when passed non-tensor', ALL_ENVS, () => { + it('input: data', () => { + const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); + const lstmBias1: tf.Tensor1D = tf.zeros([4]); + const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); + const lstmBias2: tf.Tensor1D = tf.zeros([4]); + + const forgetBias = tf.scalar(1.0); + const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); + const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); + const c = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + const h = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + + expect(() => tf.multiRNNCell([lstm1, lstm2], {} as tf.Tensor2D, c, h)) + .toThrowError( + /Argument 'data' passed to 'multiRNNCell' must be a Tensor/); + }); + + it('input: c', () => { + const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); + const lstmBias1: tf.Tensor1D = tf.zeros([4]); + const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); + const lstmBias2: tf.Tensor1D = tf.zeros([4]); + + const forgetBias = tf.scalar(1.0); + const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); + const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, 
lstmKernel2, lstmBias2, data, c, h); + + const h = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + const data: tf.Tensor2D = tf.zeros([1, 2]); + + expect(() => tf.multiRNNCell([lstm1, lstm2], data, [{} as tf.Tensor2D], h)) + .toThrowError( + /Argument 'c\[0\]' passed to 'multiRNNCell' must be a Tensor/); + }); + + it('input: h', () => { + const lstmKernel1: tf.Tensor2D = tf.zeros([3, 4]); + const lstmBias1: tf.Tensor1D = tf.zeros([4]); + const lstmKernel2: tf.Tensor2D = tf.zeros([2, 4]); + const lstmBias2: tf.Tensor1D = tf.zeros([4]); + + const forgetBias = tf.scalar(1.0); + const lstm1 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h); + const lstm2 = (data: Tensor2D, c: Tensor2D, h: Tensor2D) => + tf.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h); + const c = [ + tf.zeros([1, lstmBias1.shape[0] / 4]), + tf.zeros([1, lstmBias2.shape[0] / 4]) + ]; + const data: tf.Tensor2D = tf.zeros([1, 2]); + + expect(() => tf.multiRNNCell([lstm1, lstm2], data, c, [{} as tf.Tensor2D])) + .toThrowError( + /Argument 'h\[0\]' passed to 'multiRNNCell' must be a Tensor/); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 5a6c48c6d8c..81c1166dc94 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -21,6 +21,7 @@ export {addN} from './add_n'; export {atan2} from './atan2'; export {avgPool} from './avg_pool'; export {avgPool3d} from './avg_pool_3d'; +export {basicLSTMCell} from './basic_lstm_cell'; export {batchToSpaceND} from './batch_to_space_nd'; export {batchNorm} from './batchnorm'; export {batchNorm2d} from './batchnorm2d'; @@ -72,6 +73,7 @@ export {maximum} from './maximum'; export {minimum} from './minimum'; export {mod} from './mod'; export {mul} from './mul'; +export {LSTMCellFunc, multiRNNCell} from './multi_rnn_cell'; export {multinomial} from './multinomial'; export {notEqual} from './not_equal'; export {oneHot} from './one_hot'; @@ -116,7 +118,6 @@ export * from './transpose'; export * from './softmax'; export * from './norm'; export * from './segment_ops'; -export * from './lstm'; export * from './moving_average'; export * from './strided_slice'; export * from './topk'; diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f5a76191f1a..43a1524980f 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -46,6 +46,7 @@ import './ops/avg_pool_3d_test'; import './ops/avg_pool_test'; import './ops/axis_util_test'; import './ops/band_part_test'; +import './ops/basic_lstm_cell_test'; import './ops/batch_to_space_nd_test'; import './ops/batchnorm_test'; import './ops/binary_ops_test'; @@ -91,12 +92,12 @@ import './ops/logical_not_test'; import './ops/logical_or_test'; import './ops/logical_xor_test'; import './ops/loss_ops_test'; -import './ops/lstm_test'; import './ops/mat_mul_test'; import './ops/max_pool_3d_test'; import './ops/max_pool_test'; import './ops/max_pool_with_argmax_test'; import './ops/moving_average_test'; +import './ops/multi_rnn_cell_test'; import './ops/multinomial_test'; import './ops/non_max_suppression_async_test'; import './ops/non_max_suppression_test';
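
Note for reviewers: this patch only moves code between modules; the op signatures and numerics are unchanged, and LSTMCellFunc is now re-exported through ops.ts (which index.ts already re-exports), so the public surface stays the same. For reference, a minimal usage sketch of the two ops after the split, mirroring the unit tests above; the shapes and weight values below are illustrative only and are not taken from this patch:

import * as tf from '@tensorflow/tfjs-core';

// Weights for one BasicLSTMCell: the kernel is [inputSize + hiddenSize, 4 * hiddenSize]
// and the bias is [4 * hiddenSize]; here inputSize = 2 and hiddenSize = 1.
const forgetBias = tf.scalar(1.0);
const lstmKernel: tf.Tensor2D = tf.zeros([3, 4]);
const lstmBias: tf.Tensor1D = tf.zeros([4]);

const data: tf.Tensor2D = tf.zeros([1, 2]);  // [batch, inputSize]
const c: tf.Tensor2D = tf.zeros([1, 1]);     // previous cell state
const h: tf.Tensor2D = tf.zeros([1, 1]);     // previous cell output

// One step of the cell; returns the next cell state and the next output.
const [newC, newH] = tf.basicLSTMCell(forgetBias, lstmKernel, lstmBias, data, c, h);
newH.print();

// multiRNNCell stacks cells; each LSTMCellFunc closes over its own weights and
// receives the previous cell's output as its input.
const cell = (d: tf.Tensor2D, cPrev: tf.Tensor2D, hPrev: tf.Tensor2D) =>
    tf.basicLSTMCell(forgetBias, lstmKernel, lstmBias, d, cPrev, hPrev);
const [nextStates, nextOutputs] = tf.multiRNNCell([cell], data, [c], [h]);
nextOutputs[0].print();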