Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions tfjs-layers/src/layers/convolutional_recurrent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,15 +287,6 @@ class ConvRNN2D extends RNN {
});
}

getConfig(): tfc.serialization.ConfigDict {
const {'cell': _, ...config} = super.getConfig();

const cellConfig = this.cell.getConfig();

// this order is necessary, to prevent cell name from replacing layer name
return {...cellConfig, ...config};
}

protected computeSingleOutputShape(inputShape: Shape): Shape {
const {dataFormat, filters, kernelSize, padding, strides, dilationRate} =
this.cell;
Expand Down Expand Up @@ -532,6 +523,7 @@ export class ConvLSTM2DCell extends LSTMCell implements ConvRNN2DCell {
padding: this.padding,
dataFormat: this.dataFormat,
dilationRate: this.dilationRate,
strides: this.strides,
};

return {...baseConfig, ...config};
Expand Down
112 changes: 111 additions & 1 deletion tfjs-layers/src/layers/convolutional_recurrent_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@
*/

import * as tfc from '@tensorflow/tfjs-core';
import {serializeActivation, Tanh} from '../activations';
import {NonNeg, serializeConstraint, UnitNorm} from '../constraints';

import {sequential} from '../exports';
import * as tfl from '../index';
import {GlorotUniform, HeUniform, Ones, serializeInitializer} from '../initializers';
import {DataFormat, PaddingMode} from '../keras_format/common';
import {modelFromJSON} from '../models';
import {L1L2, serializeRegularizer} from '../regularizers';
import {getCartesianProductOfValues} from '../utils/generic_utils';
import {convertPythonicToTs, convertTsToPythonic} from '../utils/serialization_utils';
import {describeMathCPU, describeMathCPUAndGPU, expectTensorsClose} from '../utils/test_utils';

import {ConvLSTM2DArgs} from './convolutional_recurrent';
import {ConvLSTM2DArgs, ConvLSTM2DCellArgs} from './convolutional_recurrent';

describeMathCPUAndGPU('ConvLSTM2DCell', () => {
/**
Expand Down Expand Up @@ -737,6 +741,112 @@ describeMathCPU('should run BPTT correctly', () => {
});

describeMathCPU('ConvLSTM2D Serialization and Deserialization', () => {
const cellConfig: ConvLSTM2DCellArgs = {
filters: 32,
kernelSize: 3,
dataFormat: 'channelsLast',
dilationRate: 2,
padding: 'same',
strides: 2,
activation: 'tanh',
recurrentActivation: 'tanh',
useBias: true,
kernelInitializer: 'glorotUniform',
recurrentInitializer: 'heUniform',
biasInitializer: 'ones',
kernelRegularizer: 'l1l2',
recurrentRegularizer: 'l1l2',
biasRegularizer: 'l1l2',
kernelConstraint: 'unitNorm',
recurrentConstraint: 'unitNorm',
biasConstraint: 'nonNeg',
dropout: 0.1,
recurrentDropout: 0.2,
name: 'cell_1',
batchSize: 12,
batchInputShape: [12, 8, 8],
inputShape: [8, 8],
dtype: 'int32',
inputDType: 'int32',
trainable: true,
implementation: 1,
unitForgetBias: true,
};

const expectedCellConfigPrime = {
name: 'cell_1',
trainable: true,
batchInputShape: [12, 8, 8],
dtype: 'int32',
filters: 32,
kernelSize: [3, 3],
dataFormat: 'channelsLast',
dilationRate: [2, 2],
padding: 'same',
strides: [2, 2],
activation: serializeActivation(new Tanh()),
recurrentActivation: serializeActivation(new Tanh()),
useBias: true,
kernelInitializer: serializeInitializer(new GlorotUniform()),
recurrentInitializer: serializeInitializer(new HeUniform()),
biasInitializer: serializeInitializer(new Ones()),
kernelRegularizer: serializeRegularizer(new L1L2()),
recurrentRegularizer: serializeRegularizer(new L1L2()),
biasRegularizer: serializeRegularizer(new L1L2()),
activityRegularizer: serializeRegularizer(null),
kernelConstraint: serializeConstraint(new UnitNorm({})),
recurrentConstraint: serializeConstraint(new UnitNorm({})),
biasConstraint: serializeConstraint(new NonNeg()),
implementation: 1,
unitForgetBias: true,
};

describe('ConvLSTM2DCell.getConfig', () => {
it('should return the expected values', () => {
const cell = tfl.layers.convLstm2dCell(cellConfig);

const {dropout, recurrentDropout, ...configPrime} = cell.getConfig();

expect(configPrime).toEqual(expectedCellConfigPrime);
expect(dropout).toBeCloseTo(0.1);
expect(recurrentDropout).toBeCloseTo(0.2);
});
});

describe('ConvLSTM2D.getConfig', () => {
it('should return the expected values', () => {
const config: ConvLSTM2DArgs = {
...cellConfig,
name: 'layer_1',
...{
returnSequences: true,
returnState: true,
stateful: true,
unroll: false,
goBackwards: true,
inputDim: 8,
inputLength: 8,
} as Omit<ConvLSTM2DArgs, keyof ConvLSTM2DCellArgs>
};

const cell = tfl.layers.convLstm2d(config);

const {dropout, recurrentDropout, ...configPrime} = cell.getConfig();

expect(configPrime).toEqual({
...expectedCellConfigPrime,
name: 'layer_1',
returnSequences: true,
returnState: true,
stateful: true,
unroll: false,
goBackwards: true,
});
expect(dropout).toBeCloseTo(0.1);
expect(recurrentDropout).toBeCloseTo(0.2);
});
});

it('should return equal outputs before and after', async () => {
const model = sequential();

Expand Down
Loading