diff --git a/tfjs-layers/src/layers/convolutional_recurrent.ts b/tfjs-layers/src/layers/convolutional_recurrent.ts index 5b0ff248b75..e5edf0feb2c 100644 --- a/tfjs-layers/src/layers/convolutional_recurrent.ts +++ b/tfjs-layers/src/layers/convolutional_recurrent.ts @@ -287,15 +287,6 @@ class ConvRNN2D extends RNN { }); } - getConfig(): tfc.serialization.ConfigDict { - const {'cell': _, ...config} = super.getConfig(); - - const cellConfig = this.cell.getConfig(); - - // this order is necessary, to prevent cell name from replacing layer name - return {...cellConfig, ...config}; - } - protected computeSingleOutputShape(inputShape: Shape): Shape { const {dataFormat, filters, kernelSize, padding, strides, dilationRate} = this.cell; @@ -532,6 +523,7 @@ export class ConvLSTM2DCell extends LSTMCell implements ConvRNN2DCell { padding: this.padding, dataFormat: this.dataFormat, dilationRate: this.dilationRate, + strides: this.strides, }; return {...baseConfig, ...config}; diff --git a/tfjs-layers/src/layers/convolutional_recurrent_test.ts b/tfjs-layers/src/layers/convolutional_recurrent_test.ts index 9e1026d336f..55cb3d56629 100644 --- a/tfjs-layers/src/layers/convolutional_recurrent_test.ts +++ b/tfjs-layers/src/layers/convolutional_recurrent_test.ts @@ -10,16 +10,20 @@ */ import * as tfc from '@tensorflow/tfjs-core'; +import {serializeActivation, Tanh} from '../activations'; +import {NonNeg, serializeConstraint, UnitNorm} from '../constraints'; import {sequential} from '../exports'; import * as tfl from '../index'; +import {GlorotUniform, HeUniform, Ones, serializeInitializer} from '../initializers'; import {DataFormat, PaddingMode} from '../keras_format/common'; import {modelFromJSON} from '../models'; +import {L1L2, serializeRegularizer} from '../regularizers'; import {getCartesianProductOfValues} from '../utils/generic_utils'; import {convertPythonicToTs, convertTsToPythonic} from '../utils/serialization_utils'; import {describeMathCPU, describeMathCPUAndGPU, expectTensorsClose} from '../utils/test_utils'; -import {ConvLSTM2DArgs} from './convolutional_recurrent'; +import {ConvLSTM2DArgs, ConvLSTM2DCellArgs} from './convolutional_recurrent'; describeMathCPUAndGPU('ConvLSTM2DCell', () => { /** @@ -737,6 +741,112 @@ describeMathCPU('should run BPTT correctly', () => { }); describeMathCPU('ConvLSTM2D Serialization and Deserialization', () => { + const cellConfig: ConvLSTM2DCellArgs = { + filters: 32, + kernelSize: 3, + dataFormat: 'channelsLast', + dilationRate: 2, + padding: 'same', + strides: 2, + activation: 'tanh', + recurrentActivation: 'tanh', + useBias: true, + kernelInitializer: 'glorotUniform', + recurrentInitializer: 'heUniform', + biasInitializer: 'ones', + kernelRegularizer: 'l1l2', + recurrentRegularizer: 'l1l2', + biasRegularizer: 'l1l2', + kernelConstraint: 'unitNorm', + recurrentConstraint: 'unitNorm', + biasConstraint: 'nonNeg', + dropout: 0.1, + recurrentDropout: 0.2, + name: 'cell_1', + batchSize: 12, + batchInputShape: [12, 8, 8], + inputShape: [8, 8], + dtype: 'int32', + inputDType: 'int32', + trainable: true, + implementation: 1, + unitForgetBias: true, + }; + + const expectedCellConfigPrime = { + name: 'cell_1', + trainable: true, + batchInputShape: [12, 8, 8], + dtype: 'int32', + filters: 32, + kernelSize: [3, 3], + dataFormat: 'channelsLast', + dilationRate: [2, 2], + padding: 'same', + strides: [2, 2], + activation: serializeActivation(new Tanh()), + recurrentActivation: serializeActivation(new Tanh()), + useBias: true, + kernelInitializer: serializeInitializer(new GlorotUniform()), + recurrentInitializer: serializeInitializer(new HeUniform()), + biasInitializer: serializeInitializer(new Ones()), + kernelRegularizer: serializeRegularizer(new L1L2()), + recurrentRegularizer: serializeRegularizer(new L1L2()), + biasRegularizer: serializeRegularizer(new L1L2()), + activityRegularizer: serializeRegularizer(null), + kernelConstraint: serializeConstraint(new UnitNorm({})), + recurrentConstraint: serializeConstraint(new UnitNorm({})), + biasConstraint: serializeConstraint(new NonNeg()), + implementation: 1, + unitForgetBias: true, + }; + + describe('ConvLSTM2DCell.getConfig', () => { + it('should return the expected values', () => { + const cell = tfl.layers.convLstm2dCell(cellConfig); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual(expectedCellConfigPrime); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); + + describe('ConvLSTM2D.getConfig', () => { + it('should return the expected values', () => { + const config: ConvLSTM2DArgs = { + ...cellConfig, + name: 'layer_1', + ...{ + returnSequences: true, + returnState: true, + stateful: true, + unroll: false, + goBackwards: true, + inputDim: 8, + inputLength: 8, + } as Omit + }; + + const cell = tfl.layers.convLstm2d(config); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual({ + ...expectedCellConfigPrime, + name: 'layer_1', + returnSequences: true, + returnState: true, + stateful: true, + unroll: false, + goBackwards: true, + }); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); + it('should return equal outputs before and after', async () => { const model = sequential(); diff --git a/tfjs-layers/src/layers/recurrent.ts b/tfjs-layers/src/layers/recurrent.ts index 41db82807a6..e9ea508d849 100644 --- a/tfjs-layers/src/layers/recurrent.ts +++ b/tfjs-layers/src/layers/recurrent.ts @@ -807,6 +807,8 @@ export class RNN extends Layer { } getConfig(): serialization.ConfigDict { + const baseConfig = super.getConfig(); + const config: serialization.ConfigDict = { returnSequences: this.returnSequences, returnState: this.returnState, @@ -814,17 +816,22 @@ export class RNN extends Layer { stateful: this.stateful, unroll: this.unroll, }; + if (this.numConstants != null) { config['numConstants'] = this.numConstants; } + const cellConfig = this.cell.getConfig(); - config['cell'] = { - 'className': this.cell.getClassName(), - 'config': cellConfig, - } as serialization.ConfigDictValue; - const baseConfig = super.getConfig(); - Object.assign(config, baseConfig); - return config; + + if (this.getClassName() === RNN.className) { + config['cell'] = { + 'className': this.cell.getClassName(), + 'config': cellConfig, + } as serialization.ConfigDictValue; + } + + // this order is necessary, to prevent cell name from replacing layer name + return {...cellConfig, ...baseConfig, ...config}; } /** @nocollapse */ @@ -1082,6 +1089,8 @@ export class SimpleRNNCell extends RNNCell { } getConfig(): serialization.ConfigDict { + const baseConfig = super.getConfig(); + const config: serialization.ConfigDict = { units: this.units, activation: serializeActivation(this.activation), @@ -1099,9 +1108,8 @@ export class SimpleRNNCell extends RNNCell { dropout: this.dropout, recurrentDropout: this.recurrentDropout, }; - const baseConfig = super.getConfig(); - Object.assign(config, baseConfig); - return config; + + return {...baseConfig, ...config}; } } serialization.registerClass(SimpleRNNCell); @@ -1222,88 +1230,6 @@ export class SimpleRNN extends RNN { }); } - // TODO(cais): Research possibility of refactoring out the tedious all - // the getters that delegate to `this.cell` below. - get units(): number { - return (this.cell as SimpleRNNCell).units; - } - - get activation(): Activation { - return (this.cell as SimpleRNNCell).activation; - } - - get useBias(): boolean { - return (this.cell as SimpleRNNCell).useBias; - } - - get kernelInitializer(): Initializer { - return (this.cell as SimpleRNNCell).kernelInitializer; - } - - get recurrentInitializer(): Initializer { - return (this.cell as SimpleRNNCell).recurrentInitializer; - } - - get biasInitializer(): Initializer { - return (this.cell as SimpleRNNCell).biasInitializer; - } - - get kernelRegularizer(): Regularizer { - return (this.cell as SimpleRNNCell).kernelRegularizer; - } - - get recurrentRegularizer(): Regularizer { - return (this.cell as SimpleRNNCell).recurrentRegularizer; - } - - get biasRegularizer(): Regularizer { - return (this.cell as SimpleRNNCell).biasRegularizer; - } - - get kernelConstraint(): Constraint { - return (this.cell as SimpleRNNCell).kernelConstraint; - } - - get recurrentConstraint(): Constraint { - return (this.cell as SimpleRNNCell).recurrentConstraint; - } - - get biasConstraint(): Constraint { - return (this.cell as SimpleRNNCell).biasConstraint; - } - - get dropout(): number { - return (this.cell as SimpleRNNCell).dropout; - } - - get recurrentDropout(): number { - return (this.cell as SimpleRNNCell).recurrentDropout; - } - - getConfig(): serialization.ConfigDict { - const config: serialization.ConfigDict = { - units: this.units, - activation: serializeActivation(this.activation), - useBias: this.useBias, - kernelInitializer: serializeInitializer(this.kernelInitializer), - recurrentInitializer: serializeInitializer(this.recurrentInitializer), - biasInitializer: serializeInitializer(this.biasInitializer), - kernelRegularizer: serializeRegularizer(this.kernelRegularizer), - recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), - biasRegularizer: serializeRegularizer(this.biasRegularizer), - activityRegularizer: serializeRegularizer(this.activityRegularizer), - kernelConstraint: serializeConstraint(this.kernelConstraint), - recurrentConstraint: serializeConstraint(this.recurrentConstraint), - biasConstraint: serializeConstraint(this.biasConstraint), - dropout: this.dropout, - recurrentDropout: this.recurrentDropout, - }; - const baseConfig = super.getConfig(); - delete baseConfig['cell']; - Object.assign(config, baseConfig); - return config; - } - /** @nocollapse */ static fromConfig( cls: serialization.SerializableConstructor, @@ -1526,6 +1452,8 @@ export class GRUCell extends RNNCell { } getConfig(): serialization.ConfigDict { + const baseConfig = super.getConfig(); + const config: serialization.ConfigDict = { units: this.units, activation: serializeActivation(this.activation), @@ -1546,9 +1474,8 @@ export class GRUCell extends RNNCell { implementation: this.implementation, resetAfter: false }; - const baseConfig = super.getConfig(); - Object.assign(config, baseConfig); - return config; + + return {...baseConfig, ...config}; } } serialization.registerClass(GRUCell); @@ -1613,97 +1540,6 @@ export class GRU extends RNN { }); } - get units(): number { - return (this.cell as GRUCell).units; - } - - get activation(): Activation { - return (this.cell as GRUCell).activation; - } - - get recurrentActivation(): Activation { - return (this.cell as GRUCell).recurrentActivation; - } - - get useBias(): boolean { - return (this.cell as GRUCell).useBias; - } - - get kernelInitializer(): Initializer { - return (this.cell as GRUCell).kernelInitializer; - } - - get recurrentInitializer(): Initializer { - return (this.cell as GRUCell).recurrentInitializer; - } - - get biasInitializer(): Initializer { - return (this.cell as GRUCell).biasInitializer; - } - - get kernelRegularizer(): Regularizer { - return (this.cell as GRUCell).kernelRegularizer; - } - - get recurrentRegularizer(): Regularizer { - return (this.cell as GRUCell).recurrentRegularizer; - } - - get biasRegularizer(): Regularizer { - return (this.cell as GRUCell).biasRegularizer; - } - - get kernelConstraint(): Constraint { - return (this.cell as GRUCell).kernelConstraint; - } - - get recurrentConstraint(): Constraint { - return (this.cell as GRUCell).recurrentConstraint; - } - - get biasConstraint(): Constraint { - return (this.cell as GRUCell).biasConstraint; - } - - get dropout(): number { - return (this.cell as GRUCell).dropout; - } - - get recurrentDropout(): number { - return (this.cell as GRUCell).recurrentDropout; - } - - get implementation(): number { - return (this.cell as GRUCell).implementation; - } - - getConfig(): serialization.ConfigDict { - const config: serialization.ConfigDict = { - units: this.units, - activation: serializeActivation(this.activation), - recurrentActivation: serializeActivation(this.recurrentActivation), - useBias: this.useBias, - kernelInitializer: serializeInitializer(this.kernelInitializer), - recurrentInitializer: serializeInitializer(this.recurrentInitializer), - biasInitializer: serializeInitializer(this.biasInitializer), - kernelRegularizer: serializeRegularizer(this.kernelRegularizer), - recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), - biasRegularizer: serializeRegularizer(this.biasRegularizer), - activityRegularizer: serializeRegularizer(this.activityRegularizer), - kernelConstraint: serializeConstraint(this.kernelConstraint), - recurrentConstraint: serializeConstraint(this.recurrentConstraint), - biasConstraint: serializeConstraint(this.biasConstraint), - dropout: this.dropout, - recurrentDropout: this.recurrentDropout, - implementation: this.implementation, - resetAfter: false - }; - const baseConfig = super.getConfig(); - delete baseConfig['cell']; - Object.assign(config, baseConfig); - return config; - } - /** @nocollapse */ static fromConfig( cls: serialization.SerializableConstructor, @@ -1943,6 +1779,8 @@ export class LSTMCell extends RNNCell { } getConfig(): serialization.ConfigDict { + const baseConfig = super.getConfig(); + const config: serialization.ConfigDict = { units: this.units, activation: serializeActivation(this.activation), @@ -1963,9 +1801,8 @@ export class LSTMCell extends RNNCell { recurrentDropout: this.recurrentDropout, implementation: this.implementation, }; - const baseConfig = super.getConfig(); - Object.assign(config, baseConfig); - return config; + + return {...baseConfig, ...config}; } } serialization.registerClass(LSTMCell); @@ -2037,101 +1874,6 @@ export class LSTM extends RNN { }); } - get units(): number { - return (this.cell as LSTMCell).units; - } - - get activation(): Activation { - return (this.cell as LSTMCell).activation; - } - - get recurrentActivation(): Activation { - return (this.cell as LSTMCell).recurrentActivation; - } - - get useBias(): boolean { - return (this.cell as LSTMCell).useBias; - } - - get kernelInitializer(): Initializer { - return (this.cell as LSTMCell).kernelInitializer; - } - - get recurrentInitializer(): Initializer { - return (this.cell as LSTMCell).recurrentInitializer; - } - - get biasInitializer(): Initializer { - return (this.cell as LSTMCell).biasInitializer; - } - - get unitForgetBias(): boolean { - return (this.cell as LSTMCell).unitForgetBias; - } - - get kernelRegularizer(): Regularizer { - return (this.cell as LSTMCell).kernelRegularizer; - } - - get recurrentRegularizer(): Regularizer { - return (this.cell as LSTMCell).recurrentRegularizer; - } - - get biasRegularizer(): Regularizer { - return (this.cell as LSTMCell).biasRegularizer; - } - - get kernelConstraint(): Constraint { - return (this.cell as LSTMCell).kernelConstraint; - } - - get recurrentConstraint(): Constraint { - return (this.cell as LSTMCell).recurrentConstraint; - } - - get biasConstraint(): Constraint { - return (this.cell as LSTMCell).biasConstraint; - } - - get dropout(): number { - return (this.cell as LSTMCell).dropout; - } - - get recurrentDropout(): number { - return (this.cell as LSTMCell).recurrentDropout; - } - - get implementation(): number { - return (this.cell as LSTMCell).implementation; - } - - getConfig(): serialization.ConfigDict { - const config: serialization.ConfigDict = { - units: this.units, - activation: serializeActivation(this.activation), - recurrentActivation: serializeActivation(this.recurrentActivation), - useBias: this.useBias, - kernelInitializer: serializeInitializer(this.kernelInitializer), - recurrentInitializer: serializeInitializer(this.recurrentInitializer), - biasInitializer: serializeInitializer(this.biasInitializer), - unitForgetBias: this.unitForgetBias, - kernelRegularizer: serializeRegularizer(this.kernelRegularizer), - recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), - biasRegularizer: serializeRegularizer(this.biasRegularizer), - activityRegularizer: serializeRegularizer(this.activityRegularizer), - kernelConstraint: serializeConstraint(this.kernelConstraint), - recurrentConstraint: serializeConstraint(this.recurrentConstraint), - biasConstraint: serializeConstraint(this.biasConstraint), - dropout: this.dropout, - recurrentDropout: this.recurrentDropout, - implementation: this.implementation, - }; - const baseConfig = super.getConfig(); - delete baseConfig['cell']; - Object.assign(config, baseConfig); - return config; - } - /** @nocollapse */ static fromConfig( cls: serialization.SerializableConstructor, @@ -2243,17 +1985,20 @@ export class StackedRNNCells extends RNNCell { } getConfig(): serialization.ConfigDict { - const cellConfigs: serialization.ConfigDict[] = []; - for (const cell of this.cells) { - cellConfigs.push({ + const baseConfig = super.getConfig(); + + const getCellConfig = (cell: RNNCell) => { + return { 'className': cell.getClassName(), 'config': cell.getConfig(), - }); - } - const config: serialization.ConfigDict = {'cells': cellConfigs}; - const baseConfig = super.getConfig(); - Object.assign(config, baseConfig); - return config; + }; + }; + + const cellConfigs = this.cells.map(getCellConfig); + + const config = {'cells': cellConfigs}; + + return {...baseConfig, ...config}; } /** @nocollapse */ diff --git a/tfjs-layers/src/layers/recurrent_test.ts b/tfjs-layers/src/layers/recurrent_test.ts index 17688110176..19d5f520a9f 100644 --- a/tfjs-layers/src/layers/recurrent_test.ts +++ b/tfjs-layers/src/layers/recurrent_test.ts @@ -15,15 +15,19 @@ import * as tfc from '@tensorflow/tfjs-core'; import {io, randomNormal, scalar, Tensor, tensor1d, tensor2d, tensor3d, tensor4d} from '@tensorflow/tfjs-core'; +import {serializeActivation, Tanh} from '../activations'; import * as K from '../backend/tfjs_backend'; +import {NonNeg, serializeConstraint, UnitNorm} from '../constraints'; import * as tfl from '../index'; +import {GlorotUniform, HeUniform, Ones, serializeInitializer} from '../initializers'; import {ActivationIdentifier} from '../keras_format/activation_config'; import {ModelAndWeightsConfig, modelFromJSON} from '../models'; +import {L1L2, serializeRegularizer} from '../regularizers'; import {Kwargs} from '../types'; import {convertPythonicToTs, convertTsToPythonic} from '../utils/serialization_utils'; import {describeMathCPU, describeMathCPUAndGPU, describeMathGPU, expectTensorsClose} from '../utils/test_utils'; -import {GRU, LSTM, rnn, RNN, RNNCell} from './recurrent'; +import {GRU, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCellLayerArgs, LSTMLayerArgs, rnn, RNN, RNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs} from './recurrent'; /** * A simplistic RNN step function for testing. @@ -47,13 +51,10 @@ describeMathCPUAndGPU('rnn', () => { const inputs = tensor3d( [[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], [2, 3, 2]); const initialStates = [tfc.zeros([2, 4])]; - const rnnOutputs = rnn( - rnnStepForTest, inputs, initialStates, - false /* goBackwards */, - null /* mask */, - null /* constants */, - false /* unroll */, - true /* needPerStepOutputs */); + const rnnOutputs = + rnn(rnnStepForTest, inputs, initialStates, false /* goBackwards */, + null /* mask */, null /* constants */, false /* unroll */, + true /* needPerStepOutputs */); const lastOutput = rnnOutputs[0]; const outputs = rnnOutputs[1]; const newStates = rnnOutputs[2]; @@ -61,8 +62,7 @@ describeMathCPUAndGPU('rnn', () => { lastOutput, tensor2d( [ - [-57.75, -57.75, -57.75, -57.75], - [-57.75, -57.75, -57.75, -57.75] + [-57.75, -57.75, -57.75, -57.75], [-57.75, -57.75, -57.75, -57.75] ], [2, 4])); expectTensorsClose( @@ -92,13 +92,10 @@ describeMathCPUAndGPU('rnn', () => { [[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], [2, 3, 2]); // The two state tensors have different shapes. const initialStates = [tfc.zeros([2, 4]), tfc.ones([2, 3])]; - const rnnOutputs = rnn( - rnnStepForTest, inputs, initialStates, - false /* goBackwards */, - null /* mask */, - null /* constants */, - false /* unroll */, - true /* needPerStepOutputs */); + const rnnOutputs = + rnn(rnnStepForTest, inputs, initialStates, false /* goBackwards */, + null /* mask */, null /* constants */, false /* unroll */, + true /* needPerStepOutputs */); const lastOutput = rnnOutputs[0]; const outputs = rnnOutputs[1]; const newStates = rnnOutputs[2]; @@ -106,8 +103,7 @@ describeMathCPUAndGPU('rnn', () => { lastOutput, tensor2d( [ - [-57.75, -57.75, -57.75, -57.75], - [-57.75, -57.75, -57.75, -57.75] + [-57.75, -57.75, -57.75, -57.75], [-57.75, -57.75, -57.75, -57.75] ], [2, 4])); expectTensorsClose( @@ -144,13 +140,10 @@ describeMathCPUAndGPU('rnn', () => { [2, 3, 2, 1]); // The two state tensors have different shapes. const initialStates = [tfc.zeros([2, 4]), tfc.ones([2, 3])]; - const rnnOutputs = rnn( - rnnStepForTest, inputs, initialStates, - false /* goBackwards */, - null /* mask */, - null /* constants */, - false /* unroll */, - true /* needPerStepOutputs */); + const rnnOutputs = + rnn(rnnStepForTest, inputs, initialStates, false /* goBackwards */, + null /* mask */, null /* constants */, false /* unroll */, + true /* needPerStepOutputs */); const lastOutput = rnnOutputs[0]; const outputs = rnnOutputs[1]; const newStates = rnnOutputs[2]; @@ -158,8 +151,7 @@ describeMathCPUAndGPU('rnn', () => { lastOutput, tensor2d( [ - [-57.75, -57.75, -57.75, -57.75], - [-57.75, -57.75, -57.75, -57.75] + [-57.75, -57.75, -57.75, -57.75], [-57.75, -57.75, -57.75, -57.75] ], [2, 4])); expectTensorsClose( @@ -371,8 +363,7 @@ describeMathCPUAndGPU('RNN-Layer-Math', () => { const inputs = tensor3d( [[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], [2, 3, 2]); const outputs = rnn.apply(inputs) as Tensor; - expectTensorsClose( - outputs, tfc.mul(scalar(-57.75), tfc.ones([2, 4]))); + expectTensorsClose(outputs, tfc.mul(scalar(-57.75), tfc.ones([2, 4]))); }); it('apply: 1 state: returnSequences=true, returnState=false', () => { @@ -463,8 +454,7 @@ describeMathCPUAndGPU('RNN-Layer-Math', () => { [[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], [2, 3, 2]); const outputs = rnn.apply(inputs, {'initialState': [tfc.ones([2, 4])]}) as Tensor; - expectTensorsClose( - outputs, tfc.mul(scalar(-58.75), tfc.ones([2, 4]))); + expectTensorsClose(outputs, tfc.mul(scalar(-58.75), tfc.ones([2, 4]))); }); it('call: with 2 initialStates', () => { @@ -473,17 +463,12 @@ describeMathCPUAndGPU('RNN-Layer-Math', () => { const inputs = tensor3d( [[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], [2, 3, 2]); const outputs = rnn.apply(inputs, { - 'initialState': [ - tfc.ones([2, 4]), tfc.mul(scalar(2), tfc.ones([2, 5])) - ] + 'initialState': [tfc.ones([2, 4]), tfc.mul(scalar(2), tfc.ones([2, 5]))] }) as Tensor[]; expect(outputs.length).toEqual(3); - expectTensorsClose( - outputs[0], tfc.mul(scalar(-58.75), tfc.ones([2, 4]))); - expectTensorsClose( - outputs[1], tfc.mul(scalar(58.75), tfc.ones([2, 4]))); - expectTensorsClose( - outputs[2], tfc.mul(scalar(59.75), tfc.ones([2, 5]))); + expectTensorsClose(outputs[0], tfc.mul(scalar(-58.75), tfc.ones([2, 4]))); + expectTensorsClose(outputs[1], tfc.mul(scalar(58.75), tfc.ones([2, 4]))); + expectTensorsClose(outputs[2], tfc.mul(scalar(59.75), tfc.ones([2, 5]))); }); it('call with incorrect number of initialStates leads to ValueError', () => { @@ -544,12 +529,15 @@ describeMathCPU('SimpleRNN Symbolic', () => { }); it('Invalid units leads to Error', () => { - expect(() => tfl.layers.simpleRNN({units: 12.5})) - .toThrowError(/units.*positive integer.*12\.5\.$/); - expect(() => tfl.layers.simpleRNN({units: 0})) - .toThrowError(/units.*positive integer.*0\.$/); - expect(() => tfl.layers.simpleRNN({units: -25})) - .toThrowError(/units.*positive integer.*-25\.$/); + expect(() => tfl.layers.simpleRNN({ + units: 12.5 + })).toThrowError(/units.*positive integer.*12\.5\.$/); + expect(() => tfl.layers.simpleRNN({ + units: 0 + })).toThrowError(/units.*positive integer.*0\.$/); + expect(() => tfl.layers.simpleRNN({ + units: -25 + })).toThrowError(/units.*positive integer.*-25\.$/); }); }); @@ -577,17 +565,17 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { biasInitializer: 'ones', dropout, }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(simpleRNN.apply(input, kwargs) as Tensor); if (dropout !== 0.0 && training) { - expect(tfc.dropout).toHaveBeenCalledTimes(1 * (i + 1)); + expect(tfc.dropout).toHaveBeenCalledTimes(1 * (i + 1)); } else { expect(tfc.dropout).toHaveBeenCalledTimes(0); } @@ -616,19 +604,19 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { biasInitializer: 'ones', recurrentDropout, }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(simpleRNN.apply(input, kwargs) as Tensor); if (recurrentDropout !== 0.0 && training) { expect(tfc.dropout).toHaveBeenCalledTimes(1 * (i + 1)); } else { - expect(tfc.dropout).toHaveBeenCalledTimes(0); + expect(tfc.dropout).toHaveBeenCalledTimes(0); } if (i === 0) { numTensors = tfc.memory().numTensors; @@ -640,7 +628,7 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { } } - const activations : ActivationIdentifier[] = ['linear', 'tanh']; + const activations: ActivationIdentifier[] = ['linear', 'tanh']; for (const activation of activations) { const testTitle = `returnSequences=false, returnState=false, useBias=true, ${activation}`; @@ -661,8 +649,7 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { } expectTensorsClose( output, - tfc.mul( - scalar(expectedElementValue), tfc.ones([batchSize, units]))); + tfc.mul(scalar(expectedElementValue), tfc.ones([batchSize, units]))); }); } @@ -699,8 +686,7 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { const outputT1 = K.sliceAlongFirstAxis(timeMajorOutput, 1, 1); expectTensorsClose( outputT0, - tfc.mul( - scalar(inputSize + 1), tfc.ones([1, batchSize, units]))); + tfc.mul(scalar(inputSize + 1), tfc.ones([1, batchSize, units]))); expectTensorsClose( outputT1, tfc.mul( @@ -711,7 +697,7 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { } }); - it('stateful missing batchInputShape leads to error' , () => { + it('stateful missing batchInputShape leads to error', () => { const sequenceLength = 3; const simpleRNN = tfl.layers.simpleRNN({ units, @@ -723,8 +709,8 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { inputShape: [sequenceLength, inputSize] }) as RNN; const model = tfl.sequential(); - expect(() => model.add(simpleRNN)).toThrowError( - /needs to know its batch size/); + expect(() => model.add(simpleRNN)) + .toThrowError(/needs to know its batch size/); }); // The reference values below can be obtained with PyKeras code: @@ -849,10 +835,7 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { stateful: true, batchInputShape: [batchSize, sequenceLength, inputSize] })); - model.add(tfl.layers.dense({ - units: 1, - kernelInitializer: 'ones' - })); + model.add(tfl.layers.dense({units: 1, kernelInitializer: 'ones'})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); const xs = tfc.ones([batchSize, sequenceLength, inputSize]); const ys = tfc.ones([batchSize, 1]); @@ -971,11 +954,8 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { stateful: true, batchInputShape: [batchSize, sequenceLength, inputSize] })); - model.add(tfl.layers.dense({ - units: 1, - kernelInitializer: 'zeros', - useBias: false - })); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'zeros', useBias: false})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); const xs1 = tfc.ones([batchSize, sequenceLength, inputSize]); @@ -1054,6 +1034,98 @@ describeMathCPUAndGPU('SimpleRNN Tensor', () => { }); }); +describeMathCPU('SimpleRNN Serialization', () => { + const cellConfig: SimpleRNNCellLayerArgs = { + units: 8, + activation: 'tanh', + useBias: true, + kernelInitializer: 'glorotUniform', + recurrentInitializer: 'heUniform', + biasInitializer: 'ones', + kernelRegularizer: 'l1l2', + recurrentRegularizer: 'l1l2', + biasRegularizer: 'l1l2', + kernelConstraint: 'unitNorm', + recurrentConstraint: 'unitNorm', + biasConstraint: 'nonNeg', + dropout: 0.1, + recurrentDropout: 0.2, + name: 'cell_1', + batchSize: 12, + batchInputShape: [12, 8, 8], + inputShape: [8, 8], + dtype: 'int32', + inputDType: 'int32', + trainable: true, + }; + + const expectedCellConfigPrime = { + name: 'cell_1', + trainable: true, + batchInputShape: [12, 8, 8], + dtype: 'int32', + units: 8, + activation: serializeActivation(new Tanh()), + useBias: true, + kernelInitializer: serializeInitializer(new GlorotUniform()), + recurrentInitializer: serializeInitializer(new HeUniform()), + biasInitializer: serializeInitializer(new Ones()), + kernelRegularizer: serializeRegularizer(new L1L2()), + recurrentRegularizer: serializeRegularizer(new L1L2()), + biasRegularizer: serializeRegularizer(new L1L2()), + activityRegularizer: serializeRegularizer(null), + kernelConstraint: serializeConstraint(new UnitNorm({})), + recurrentConstraint: serializeConstraint(new UnitNorm({})), + biasConstraint: serializeConstraint(new NonNeg()), + }; + + describe('SimpleRNNCell.getConfig', () => { + it('should return the expected values', () => { + const cell = tfl.layers.simpleRNNCell(cellConfig); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual(expectedCellConfigPrime); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); + + describe('SimpleRNN.getConfig', () => { + it('should return the expected values', () => { + const config: SimpleRNNLayerArgs = { + ...cellConfig, + name: 'layer_1', + ...{ + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + inputDim: 8, + inputLength: 8, + } as Omit + }; + + const cell = tfl.layers.simpleRNN(config); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual({ + ...expectedCellConfigPrime, + name: 'layer_1', + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + }); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); +}); + describeMathCPU('GRU Symbolic', () => { it('returnSequences=false, returnState=false', () => { const input = new tfl.SymbolicTensor('float32', [9, 10, 8], null, [], null); @@ -1117,12 +1189,15 @@ describeMathCPU('GRU Symbolic', () => { } it('Invalid units leads to Error', () => { - expect(() => tfl.layers.gru({units: 12.5})) - .toThrowError(/units.*positive integer.*12\.5\.$/); - expect(() => tfl.layers.gru({units: 0})) - .toThrowError(/units.*positive integer.*0\.$/); - expect(() => tfl.layers.gru({units: -25})) - .toThrowError(/units.*positive integer.*-25\.$/); + expect(() => tfl.layers.gru({ + units: 12.5 + })).toThrowError(/units.*positive integer.*12\.5\.$/); + expect(() => tfl.layers.gru({ + units: 0 + })).toThrowError(/units.*positive integer.*0\.$/); + expect(() => tfl.layers.gru({ + units: -25 + })).toThrowError(/units.*positive integer.*-25\.$/); }); }); @@ -1181,17 +1256,17 @@ describeMathCPUAndGPU('GRU Tensor', () => { dropout, implementation: 1 }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(gru.apply(input, kwargs) as Tensor); if (dropout !== 0.0 && training) { - expect(tfc.dropout).toHaveBeenCalledTimes(3 * (i + 1)); + expect(tfc.dropout).toHaveBeenCalledTimes(3 * (i + 1)); } else { expect(tfc.dropout).toHaveBeenCalledTimes(0); } @@ -1220,14 +1295,14 @@ describeMathCPUAndGPU('GRU Tensor', () => { recurrentDropout, implementation: 1 }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(gru.apply(input, kwargs) as Tensor); if (recurrentDropout !== 0.0 && training) { expect(tfc.dropout).toHaveBeenCalledTimes(3 * (i + 1)); @@ -1272,8 +1347,8 @@ describeMathCPUAndGPU('GRU Tensor', () => { let expectedOutput: Tensor; if (returnSequences) { const outputs = goldenOutputElementValues.map( - value => tfc.mul( - scalar(value), tfc.ones([1, batchSize, units]))); + value => + tfc.mul(scalar(value), tfc.ones([1, batchSize, units]))); expectedOutput = tfc.transpose( K.concatAlongFirstAxis( K.concatAlongFirstAxis(outputs[0], outputs[1]), outputs[2]), @@ -1420,11 +1495,8 @@ describeMathCPUAndGPU('GRU Tensor', () => { stateful: true, batchInputShape: [batchSize, sequenceLength, inputSize] })); - model.add(tfl.layers.dense({ - units: 1, - kernelInitializer: 'zeros', - useBias: false - })); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'zeros', useBias: false})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); const xs1 = tfc.ones([batchSize, sequenceLength, inputSize]); @@ -1479,11 +1551,8 @@ describeMathCPUAndGPU('GRU Tensor', () => { useBias: false, inputShape: [sequenceLength, inputSize] })); - model.add(tfl.layers.dense({ - units: 1, - kernelInitializer: 'ones', - useBias: false - })); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', useBias: false})); const sgd = tfc.train.sgd(1); model.compile({loss: 'meanSquaredError', optimizer: sgd}); @@ -1523,25 +1592,27 @@ describeMathCPUAndGPU('GRU Tensor', () => { // ``` it('SymbolicTensor as initialState thru kwargs; Save & Load', async () => { const in1 = tfl.input({shape: [5]}); - const out1 = tfl.layers.dense({ - units: 2, - kernelInitializer: 'ones', - biasInitializer: 'zeros' - }).apply(in1) as tfl.SymbolicTensor; + const out1 = + tfl.layers + .dense( + {units: 2, kernelInitializer: 'ones', biasInitializer: 'zeros'}) + .apply(in1) as tfl.SymbolicTensor; const in2 = tfl.input({shape: [3, 4]}); - const out2 = tfl.layers.gru({ - units: 2, - recurrentInitializer: 'ones', - kernelInitializer: 'ones', - biasInitializer: 'zeros' - }).apply(in2, {initialState: out1}) as tfl.SymbolicTensor; + const out2 = tfl.layers + .gru({ + units: 2, + recurrentInitializer: 'ones', + kernelInitializer: 'ones', + biasInitializer: 'zeros' + }) + .apply(in2, {initialState: out1}) as tfl.SymbolicTensor; - const model = tfl.model({inputs: [in1 , in2], outputs: [out1, out2]}); + const model = tfl.model({inputs: [in1, in2], outputs: [out1, out2]}); const xs1 = tensor2d([[0.1, 0.2, 0.3, 0.4, 0.5]]); - const xs2 = tensor3d( - [[[0.1, 0.2, 0.3, 0.4], [-0.1, -0.2, -0.3, -0.4], - [0.3, 0.4, 0.5, 0.6]]]); + const xs2 = tensor3d([ + [[0.1, 0.2, 0.3, 0.4], [-0.1, -0.2, -0.3, -0.4], [0.3, 0.4, 0.5, 0.6]] + ]); const ys = model.predict([xs1, xs2]) as Tensor[]; expect(ys.length).toEqual(2); expectTensorsClose(ys[0], tensor2d([[1.5, 1.5]])); @@ -1571,8 +1642,8 @@ describeMathCPU('GRU-deserialization', () => { const layerPrime = tfl.layers.gru(tsConfig) as GRU; const yPrime = layer.apply(x) as Tensor; expectTensorsClose(yPrime, y); - expect(layerPrime.recurrentActivation.getClassName()) - .toEqual(layer.recurrentActivation.getClassName()); + expect(layerPrime.getConfig()['recurrentActivation']) + .toEqual(layer.getConfig()['recurrentActivation']); }); it('Non-default recurrentActivation round trip', () => { @@ -1586,8 +1657,105 @@ describeMathCPU('GRU-deserialization', () => { const layerPrime = tfl.layers.gru(tsConfig) as GRU; const yPrime = layer.apply(x) as Tensor; expectTensorsClose(yPrime, y); - expect(layerPrime.recurrentActivation.getClassName()) - .toEqual(layer.recurrentActivation.getClassName()); + expect(layerPrime.getConfig()['recurrentActivation']) + .toEqual(layer.getConfig()['recurrentActivation']); + }); +}); + +describeMathCPU('GRU Serialization', () => { + const cellConfig: GRUCellLayerArgs = { + units: 8, + activation: 'tanh', + recurrentActivation: 'tanh', + useBias: true, + kernelInitializer: 'glorotUniform', + recurrentInitializer: 'heUniform', + biasInitializer: 'ones', + kernelRegularizer: 'l1l2', + recurrentRegularizer: 'l1l2', + biasRegularizer: 'l1l2', + kernelConstraint: 'unitNorm', + recurrentConstraint: 'unitNorm', + biasConstraint: 'nonNeg', + dropout: 0.1, + recurrentDropout: 0.2, + name: 'cell_1', + batchSize: 12, + batchInputShape: [12, 8, 8], + inputShape: [8, 8], + dtype: 'int32', + inputDType: 'int32', + trainable: true, + implementation: 1, + }; + + const expectedCellConfigPrime = { + name: 'cell_1', + trainable: true, + batchInputShape: [12, 8, 8], + dtype: 'int32', + units: 8, + activation: serializeActivation(new Tanh()), + recurrentActivation: serializeActivation(new Tanh()), + useBias: true, + kernelInitializer: serializeInitializer(new GlorotUniform()), + recurrentInitializer: serializeInitializer(new HeUniform()), + biasInitializer: serializeInitializer(new Ones()), + kernelRegularizer: serializeRegularizer(new L1L2()), + recurrentRegularizer: serializeRegularizer(new L1L2()), + biasRegularizer: serializeRegularizer(new L1L2()), + activityRegularizer: serializeRegularizer(null), + kernelConstraint: serializeConstraint(new UnitNorm({})), + recurrentConstraint: serializeConstraint(new UnitNorm({})), + biasConstraint: serializeConstraint(new NonNeg()), + implementation: 1, + resetAfter: false, + }; + + describe('GRUCell.getConfig', () => { + it('should return the expected values', () => { + const cell = tfl.layers.gruCell(cellConfig); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual(expectedCellConfigPrime); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); + + describe('GRU.getConfig', () => { + it('should return the expected values', () => { + const config: GRULayerArgs = { + ...cellConfig, + name: 'layer_1', + ...{ + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + inputDim: 8, + inputLength: 8, + } as Omit + }; + + const cell = tfl.layers.gru(config); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual({ + ...expectedCellConfigPrime, + name: 'layer_1', + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + }); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); }); }); @@ -1676,8 +1844,8 @@ describeMathCPU('LSTM Symbolic', () => { return null; })); - const loadedModel = await tfl.loadLayersModel( - tfc.io.fromMemory(savedArtifacts)); + const loadedModel = + await tfl.loadLayersModel(tfc.io.fromMemory(savedArtifacts)); expect(model.inputs[0].shape).toEqual(loadedModel.inputs[0].shape); expect(model.outputs[0].shape).toEqual(loadedModel.outputs[0].shape); @@ -1685,12 +1853,15 @@ describeMathCPU('LSTM Symbolic', () => { }); it('Invalid units leads to Error', () => { - expect(() => tfl.layers.lstm({units: 12.5})) - .toThrowError(/units.*positive integer.*12\.5\.$/); - expect(() => tfl.layers.lstm({units: 0})) - .toThrowError(/units.*positive integer.*0\.$/); - expect(() => tfl.layers.lstm({units: -25})) - .toThrowError(/units.*positive integer.*-25\.$/); + expect(() => tfl.layers.lstm({ + units: 12.5 + })).toThrowError(/units.*positive integer.*12\.5\.$/); + expect(() => tfl.layers.lstm({ + units: 0 + })).toThrowError(/units.*positive integer.*0\.$/); + expect(() => tfl.layers.lstm({ + units: -25 + })).toThrowError(/units.*positive integer.*-25\.$/); }); }); @@ -1749,14 +1920,14 @@ describeMathCPUAndGPU('LSTM Tensor', () => { dropout, implementation: 1 }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(lstm.apply(input, kwargs) as Tensor); if (dropout !== 0.0 && training) { expect(tfc.dropout).toHaveBeenCalledTimes(4 * (i + 1)); @@ -1788,14 +1959,14 @@ describeMathCPUAndGPU('LSTM Tensor', () => { recurrentDropout, implementation: 1 }); - const kwargs:Kwargs = {}; + const kwargs: Kwargs = {}; if (training) { kwargs['training'] = true; } const input = tfc.ones([batchSize, timeSteps, inputSize]); spyOn(tfc, 'dropout').and.callThrough(); let numTensors = 0; - for (let i = 0; i < 2; i++){ + for (let i = 0; i < 2; i++) { tfc.dispose(lstm.apply(input, kwargs) as Tensor); if (recurrentDropout !== 0.0 && training) { expect(tfc.dropout).toHaveBeenCalledTimes(4 * (i + 1)); @@ -1996,11 +2167,8 @@ describeMathCPUAndGPU('LSTM Tensor', () => { stateful: true, batchInputShape: [batchSize, sequenceLength, inputSize] })); - model.add(tfl.layers.dense({ - units: 1, - kernelInitializer: 'ones', - useBias: false - })); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', useBias: false})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); const xs1 = tfc.ones([batchSize, sequenceLength, inputSize]); @@ -2123,12 +2291,13 @@ describeMathCPUAndGPU('LSTM Tensor', () => { kernelInitializer: 'ones', biasInitializer: 'zeros' })); - model.add(tfl.layers.dense({ - units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); - const xs = tensor2d( - [[0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], - [1, 2, 3, 0, 0, 0]]); + const xs = tensor2d([ + [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], + [1, 2, 3, 0, 0, 0] + ]); const ys = model.predict(xs) as Tensor; expectTensorsClose( ys, tensor2d([[0], [2.283937], [2.891939], [2.9851441]])); @@ -2169,31 +2338,23 @@ describeMathCPUAndGPU('LSTM Tensor', () => { // ``` it('With mask, goBackwards = true', () => { const model = tfl.sequential(); - const embeddingLayer = tfl.layers.embedding({ - inputDim: 4, - outputDim: 2, - inputLength: 3, - maskZero: true - }); + const embeddingLayer = tfl.layers.embedding( + {inputDim: 4, outputDim: 2, inputLength: 3, maskZero: true}); model.add(embeddingLayer); - const lstmLayer = tfl.layers.lstm({ - units: 2, - goBackwards: true - }); + const lstmLayer = tfl.layers.lstm({units: 2, goBackwards: true}); model.add(lstmLayer); - model.add(tfl.layers.dense({ - units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); // Setting weights to asymmetric, so that the effect of goBackwards=true // can show. - embeddingLayer.setWeights([ - tensor2d([[0.1, 0.2], [0.3, 0.4], [-0.1, -0.2], [-0.3, -0.4]])]); + embeddingLayer.setWeights( + [tensor2d([[0.1, 0.2], [0.3, 0.4], [-0.1, -0.2], [-0.3, -0.4]])]); lstmLayer.setWeights([ - tensor2d([[1, 2, 3, 4, 5, 6, 7, 8], - [-1, -2, -3, -4, -5, -6, -7, -8]]), - tensor2d([[1, 2, 3, 4, 5, 6, 7, 8], - [-1, -2, -3, -4, -5, -6, -7, -8]]), - tensor1d([1, 2, 3, 4, 5, 6, 7, 8])]); + tensor2d([[1, 2, 3, 4, 5, 6, 7, 8], [-1, -2, -3, -4, -5, -6, -7, -8]]), + tensor2d([[1, 2, 3, 4, 5, 6, 7, 8], [-1, -2, -3, -4, -5, -6, -7, -8]]), + tensor1d([1, 2, 3, 4, 5, 6, 7, 8]) + ]); const xs = tensor2d([[0, 0, 0], [1, 0, 0], [1, 2, 0], [1, 2, 3]]); const ys = model.predict(xs) as Tensor; @@ -2233,10 +2394,8 @@ describeMathCPUAndGPU('LSTM Tensor', () => { // ``` it('With mask and a nested model', () => { const model = tfl.sequential(); - model.add(tfl.layers.reshape({ - targetShape: [6], - inputShape: [6] - })); // A dummy input layer. + model.add(tfl.layers.reshape( + {targetShape: [6], inputShape: [6]})); // A dummy input layer. const nestedModel = tfl.sequential(); nestedModel.add(tfl.layers.embedding({ inputDim: 10, @@ -2252,12 +2411,13 @@ describeMathCPUAndGPU('LSTM Tensor', () => { biasInitializer: 'zeros' })); model.add(nestedModel); - model.add(tfl.layers.dense({ - units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); - const xs = tensor2d( - [[0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], - [1, 2, 3, 0, 0, 0]]); + const xs = tensor2d([ + [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], + [1, 2, 3, 0, 0, 0] + ]); const ys = model.predict(xs) as Tensor; expectTensorsClose( ys, tensor2d([[0], [2.283937], [2.891939], [2.9851441]])); @@ -2308,13 +2468,14 @@ describeMathCPUAndGPU('LSTM Tensor', () => { kernelInitializer: 'ones', biasInitializer: 'zeros' })); - model.add(tfl.layers.dense({ - units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); - const xs = tensor2d( - [[0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], - [1, 2, 3, 0, 0, 0]]); + const xs = tensor2d([ + [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], + [1, 2, 3, 0, 0, 0] + ]); const ys = tensor2d([[1], [2], [3], [4]]); // Serves as burn-in call for subsequent tracking of memory leak. @@ -2358,41 +2519,47 @@ describeMathCPUAndGPU('LSTM Tensor', () => { // ``` it('With mask, returnStates = true', () => { const inp = tfl.input({shape: [6]}); - let y: tfl.SymbolicTensor|tfl.SymbolicTensor[] = tfl.layers.embedding({ - inputDim: 10, - outputDim: 4, - inputLength: 6, - maskZero: true, - embeddingsInitializer: 'ones' - }).apply(inp) as tfl.SymbolicTensor; - y = tfl.layers.lstm({ - units: 3, - returnState: true, - recurrentInitializer: 'ones', - kernelInitializer: 'ones', - biasInitializer: 'zeros' - }).apply(y) as tfl.SymbolicTensor[]; + let y: tfl.SymbolicTensor|tfl.SymbolicTensor[] = + tfl.layers + .embedding({ + inputDim: 10, + outputDim: 4, + inputLength: 6, + maskZero: true, + embeddingsInitializer: 'ones' + }) + .apply(inp) as tfl.SymbolicTensor; + y = tfl.layers + .lstm({ + units: 3, + returnState: true, + recurrentInitializer: 'ones', + kernelInitializer: 'ones', + biasInitializer: 'zeros' + }) + .apply(y) as tfl.SymbolicTensor[]; const model = tfl.model({inputs: inp, outputs: y}); - const xs = tensor2d( - [[0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], - [1, 2, 3, 0, 0, 0]]); + const xs = tensor2d([ + [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 2, 0, 0, 0, 0], + [1, 2, 3, 0, 0, 0] + ]); const ys = model.predict(xs) as Tensor[]; expect(ys.length).toEqual(3); expectTensorsClose(ys[0], tensor2d([ - [0, 0, 0], - [0.76131237, 0.76131237, 0.76131237], - [0.9639796, 0.9639796, 0.9639796], - [0.99504817, 0.99504817, 0.99504817]])); + [0, 0, 0], [0.76131237, 0.76131237, 0.76131237], + [0.9639796, 0.9639796, 0.9639796], + [0.99504817, 0.99504817, 0.99504817] + ])); expectTensorsClose(ys[1], tensor2d([ - [0, 0, 0], - [0.76131237, 0.76131237, 0.76131237], - [0.9639796, 0.9639796, 0.9639796], - [0.99504817, 0.99504817, 0.99504817]])); - expectTensorsClose(ys[2], tensor2d([ - [0, 0, 0], - [0.9993292, 0.9993292, 0.9993292], - [1.9993222, 1.9993222, 1.9993222], - [2.9993203,2.9993203, 2.9993203]])); + [0, 0, 0], [0.76131237, 0.76131237, 0.76131237], + [0.9639796, 0.9639796, 0.9639796], + [0.99504817, 0.99504817, 0.99504817] + ])); + expectTensorsClose( + ys[2], tensor2d([ + [0, 0, 0], [0.9993292, 0.9993292, 0.9993292], + [1.9993222, 1.9993222, 1.9993222], [2.9993203, 2.9993203, 2.9993203] + ])); }); // Referernce Python code: @@ -2438,16 +2605,16 @@ describeMathCPUAndGPU('LSTM Tensor', () => { kernelInitializer: 'ones', biasInitializer: 'zeros' })); - model.add(tfl.layers.dense({ - units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); + model.add(tfl.layers.dense( + {units: 1, kernelInitializer: 'ones', biasInitializer: 'zeros'})); - const xs = tensor2d( - [[0, 0, 0], [1, 0, 0], [1, 2, 0], [1, 2, 3]]); + const xs = tensor2d([[0, 0, 0], [1, 0, 0], [1, 2, 0], [1, 2, 3]]); const ys = model.predict(xs) as Tensor; - expectTensorsClose(ys, tensor3d( - [[[0], [0], [0]], [[2.283937], [2.283937], [2.283937]], - [[2.283937], [2.8919387], [2.8919387]], - [[2.283937], [2.8919387], [2.9851446]]])); + expectTensorsClose(ys, tensor3d([ + [[0], [0], [0]], [[2.283937], [2.283937], [2.283937]], + [[2.283937], [2.8919387], [2.8919387]], + [[2.283937], [2.8919387], [2.9851446]] + ])); }); // Reference Python code: @@ -2501,8 +2668,7 @@ describeMathCPUAndGPU('LSTM Tensor', () => { biasInitializer: 'zeros' })); - const xs = tensor2d( - [[0, 0, 0], [1, 0, 0], [1, 2, 0], [1, 2, 3]]); + const xs = tensor2d([[0, 0, 0], [1, 0, 0], [1, 2, 0], [1, 2, 3]]); // Burn-in call for subsequent memory leak check. model.predict(xs); @@ -2510,10 +2676,11 @@ describeMathCPUAndGPU('LSTM Tensor', () => { const numTensors0 = tfc.memory().numTensors; const ys = model.predict(xs) as Tensor; const numTensors1 = tfc.memory().numTensors; - expectTensorsClose(ys, tensor2d( - [[0, 0, 0], [0.75950104, 0.75950104, 0.75950104], - [0.96367145, 0.96367145, 0.96367145], - [0.9950049, 0.9950049, 0.9950049]])); + expectTensorsClose(ys, tensor2d([ + [0, 0, 0], [0.75950104, 0.75950104, 0.75950104], + [0.96367145, 0.96367145, 0.96367145], + [0.9950049, 0.9950049, 0.9950049] + ])); ys.dispose(); // Assert no memory leak. expect(numTensors1).toEqual(numTensors0 + 1); @@ -2539,8 +2706,8 @@ describeMathCPU('LSTM-deserialization', () => { const layerPrime = tfl.layers.lstm(tsConfig) as LSTM; const yPrime = layer.apply(x) as Tensor; expectTensorsClose(yPrime, y); - expect(layerPrime.recurrentActivation.getClassName()) - .toEqual(layer.recurrentActivation.getClassName()); + expect(layerPrime.getConfig()['recurrentActivation']) + .toEqual(layer.getConfig()['recurrentActivation']); }); it('Non-default recurrentActivation round trip', () => { @@ -2554,8 +2721,8 @@ describeMathCPU('LSTM-deserialization', () => { const layerPrime = tfl.layers.lstm(tsConfig) as LSTM; const yPrime = layer.apply(x) as Tensor; expectTensorsClose(yPrime, y); - expect(layerPrime.recurrentActivation.getClassName()) - .toEqual(layer.recurrentActivation.getClassName()); + expect(layerPrime.getConfig()['recurrentActivation']) + .toEqual(layer.getConfig()['recurrentActivation']); }); }); @@ -2680,9 +2847,9 @@ const fakeLSTMModel: ModelAndWeightsConfig = { 'return_sequences': true, 'recurrent_constraint': null }, - 'inbound_nodes': [[ - ['input_2', 0, 0, {}], ['lstm_1', 0, 1, {}], ['lstm_1', 0, 2, {}] - ]], + 'inbound_nodes': [ + [['input_2', 0, 0, {}], ['lstm_1', 0, 1, {}], ['lstm_1', 0, 2, {}]] + ], 'name': 'lstm_2' }, { @@ -2721,6 +2888,104 @@ const fakeLSTMModel: ModelAndWeightsConfig = { } }; +describeMathCPU('LSTM Serialization', () => { + const cellConfig: LSTMCellLayerArgs = { + units: 8, + activation: 'tanh', + recurrentActivation: 'tanh', + useBias: true, + kernelInitializer: 'glorotUniform', + recurrentInitializer: 'heUniform', + biasInitializer: 'ones', + kernelRegularizer: 'l1l2', + recurrentRegularizer: 'l1l2', + biasRegularizer: 'l1l2', + kernelConstraint: 'unitNorm', + recurrentConstraint: 'unitNorm', + biasConstraint: 'nonNeg', + dropout: 0.1, + recurrentDropout: 0.2, + name: 'cell_1', + batchSize: 12, + batchInputShape: [12, 8, 8], + inputShape: [8, 8], + dtype: 'int32', + inputDType: 'int32', + trainable: true, + implementation: 1, + unitForgetBias: true, + }; + + const expectedCellConfigPrime = { + name: 'cell_1', + trainable: true, + batchInputShape: [12, 8, 8], + dtype: 'int32', + units: 8, + activation: serializeActivation(new Tanh()), + recurrentActivation: serializeActivation(new Tanh()), + useBias: true, + kernelInitializer: serializeInitializer(new GlorotUniform()), + recurrentInitializer: serializeInitializer(new HeUniform()), + biasInitializer: serializeInitializer(new Ones()), + kernelRegularizer: serializeRegularizer(new L1L2()), + recurrentRegularizer: serializeRegularizer(new L1L2()), + biasRegularizer: serializeRegularizer(new L1L2()), + activityRegularizer: serializeRegularizer(null), + kernelConstraint: serializeConstraint(new UnitNorm({})), + recurrentConstraint: serializeConstraint(new UnitNorm({})), + biasConstraint: serializeConstraint(new NonNeg()), + implementation: 1, + unitForgetBias: true, + }; + + describe('LSTMCell.getConfig', () => { + it('should return the expected values', () => { + const cell = tfl.layers.lstmCell(cellConfig); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual(expectedCellConfigPrime); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); + + describe('LSTM.getConfig', () => { + it('should return the expected values', () => { + const config: LSTMLayerArgs = { + ...cellConfig, + name: 'layer_1', + ...{ + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + inputDim: 8, + inputLength: 8, + } as Omit + }; + + const cell = tfl.layers.lstm(config); + + const {dropout, recurrentDropout, ...configPrime} = cell.getConfig(); + + expect(configPrime).toEqual({ + ...expectedCellConfigPrime, + name: 'layer_1', + returnSequences: true, + returnState: true, + stateful: true, + unroll: true, + goBackwards: true, + }); + expect(dropout).toBeCloseTo(0.1); + expect(recurrentDropout).toBeCloseTo(0.2); + }); + }); +}); + describeMathCPU('StackedRNNCells Symbolic', () => { it('With SimpleRNNCell', () => { const stackedRNN = tfl.layers.rnn({ @@ -2756,8 +3021,7 @@ describeMathCPU('StackedRNNCells Symbolic', () => { cell: tfl.layers.stackedRNNCells({ cells: [ tfl.layers.lstmCell({units: 3, recurrentInitializer: 'glorotNormal'}), - tfl.layers.lstmCell( - {units: 2, recurrentInitializer: 'glorotNormal'}) + tfl.layers.lstmCell({units: 2, recurrentInitializer: 'glorotNormal'}) ], }) }); @@ -2810,22 +3074,13 @@ describeMathCPU('StackedRNNCells Symbolic', () => { describeMathCPU('Stacked RNN serialization', () => { it('StackedRNNCells', async () => { const model = tfl.sequential(); - model.add(tfl.layers.dense({ - units: 1, - inputShape: [3, 4], - kernelInitializer: 'ones' - })); + model.add(tfl.layers.dense( + {units: 1, inputShape: [3, 4], kernelInitializer: 'ones'})); const cells = [ - tfl.layers.lstmCell({ - units: 5, - kernelInitializer: 'ones', - recurrentInitializer: 'ones' - }), - tfl.layers.lstmCell({ - units: 6, - kernelInitializer: 'ones', - recurrentInitializer: 'ones' - }) + tfl.layers.lstmCell( + {units: 5, kernelInitializer: 'ones', recurrentInitializer: 'ones'}), + tfl.layers.lstmCell( + {units: 6, kernelInitializer: 'ones', recurrentInitializer: 'ones'}) ]; const rnn = tfl.layers.rnn({cell: cells, returnSequences: true}); model.add(rnn);