diff --git a/e2e/integration_tests/constants.ts b/e2e/integration_tests/constants.ts
index be5caab154c..1993110f78c 100644
--- a/e2e/integration_tests/constants.ts
+++ b/e2e/integration_tests/constants.ts
@@ -28,10 +28,8 @@ export const BACKENDS = ['cpu', 'webgl'];
 
 /** Testing models for CUJ: create -> save -> predict. */
 export const LAYERS_MODELS = [
-  // (TODO: piyu) Enable this test once gru weight shape bug is fixed.
-  'mlp', 'cnn', 'depthwise_cnn', 'simple_rnn',  //'gru',
-  'bidirectional_lstm', 'time_distributed_lstm', 'one_dimensional',
-  'functional_merge'
+  'mlp', 'cnn', 'depthwise_cnn', 'simple_rnn', 'gru', 'bidirectional_lstm',
+  'time_distributed_lstm', 'one_dimensional', 'functional_merge'
 ];
 
 export const GRAPH_MODELS = [
diff --git a/e2e/integration_tests/convert_predict.py b/e2e/integration_tests/convert_predict.py
index 363b02698f3..40c0612f94f 100644
--- a/e2e/integration_tests/convert_predict.py
+++ b/e2e/integration_tests/convert_predict.py
@@ -103,7 +103,7 @@ def _save_and_convert_model(model_fn, model_path, control_flow_v2=False):
   if control_flow_v2:
     args = args + ['--control_flow_v2', 'True']
 
-  print(args)
+  print(args, tmp_saved_model_dir, artifacts_dir)
   subprocess.check_output(args + [tmp_saved_model_dir, artifacts_dir])
 
 def _create_saved_model_v1(save_dir):
@@ -282,7 +282,7 @@ def _create_saved_model_v2_complex64(save_dir):
           "Identity:0": {"value": [4, 2], "shape": [1], "dtype": "complex64"}}}
 
 def _create_saved_model_v2_with_control_flow_v2(save_dir):
-  """Test a TF V2 model with complex dtype.
+  """Test a TF V2 model with control flow v2.
 
   Args:
     save_dir: directory name of where the saved model will be stored.
@@ -296,7 +296,10 @@ def __init__(self):
         tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32)])
     def control_flow(self, x, y):
       while x < y:
-        x = x + 2
+        if y > 0:
+          x = x + y
+        else:
+          x = x + 2
       return x
 
diff --git a/e2e/integration_tests/create_save_predict.py b/e2e/integration_tests/create_save_predict.py
index 422e639a043..32c240257d1 100644
--- a/e2e/integration_tests/create_save_predict.py
+++ b/e2e/integration_tests/create_save_predict.py
@@ -87,8 +87,7 @@ def main():
   _load_predict_save('cnn')
   _load_predict_save('depthwise_cnn')
   _load_predict_save('simple_rnn')
-  #(TODO: piyu) Enable this test once gru weight shape bug is fixed.
-  #_load_predict_save('gru')
+  _load_predict_save('gru')
   _load_predict_save('bidirectional_lstm')
   _load_predict_save('time_distributed_lstm')
   _load_predict_save('one_dimensional')
diff --git a/tfjs-layers/src/layers/recurrent.ts b/tfjs-layers/src/layers/recurrent.ts
index c368582a6bf..41881efbf8c 100644
--- a/tfjs-layers/src/layers/recurrent.ts
+++ b/tfjs-layers/src/layers/recurrent.ts
@@ -1335,6 +1335,13 @@ export declare interface GRUCellLayerArgs extends SimpleRNNCellLayerArgs {
    * 2, regardless of the actual value of this configuration field.
    */
   implementation?: number;
+
+  /**
+   * GRU convention (whether to apply reset gate after or before matrix
+   * multiplication). false = "before", true = "after" (only false is
+   * supported).
+   */
+  resetAfter?: boolean;
 }
 
 export class GRUCell extends RNNCell {
@@ -1376,7 +1383,10 @@ export class GRUCell extends RNNCell {
 
   constructor(args: GRUCellLayerArgs) {
     super(args);
-
+    if (args.resetAfter) {
+      throw new ValueError(
+          `GRUCell does not support reset_after parameter set to true.`);
+    }
     this.units = args.units;
     assertPositiveInteger(this.units, 'units');
     this.activation = getActivation(
@@ -1525,6 +1535,7 @@ export class GRUCell extends RNNCell {
       dropout: this.dropout,
       recurrentDropout: this.recurrentDropout,
       implementation: this.implementation,
+      resetAfter: false
     };
     const baseConfig = super.getConfig();
     Object.assign(config, baseConfig);
@@ -1676,6 +1687,7 @@ export class GRU extends RNN {
      dropout: this.dropout,
      recurrentDropout: this.recurrentDropout,
      implementation: this.implementation,
+     resetAfter: false
     };
     const baseConfig = super.getConfig();
     delete baseConfig['cell'];
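Note (not part of the patch): a minimal sketch of the user-visible behavior this change introduces, assuming the public tf.layers.gruCell factory and an illustrative unit count. The new resetAfter field is accepted in the args, serialized as false in getConfig(), and rejected up front when set to true instead of failing later with mismatched weight shapes.

import * as tf from '@tensorflow/tfjs';

// The default convention (resetAfter omitted or false) keeps working,
// and getConfig() now records resetAfter: false explicitly.
const cell = tf.layers.gruCell({units: 4});
console.log(cell.getConfig()['resetAfter']);  // false

// Requesting the reset-after convention is rejected at construction
// time by the new guard in the GRUCell constructor.
try {
  tf.layers.gruCell({units: 4, resetAfter: true});
} catch (e) {
  console.log((e as Error).message);
  // GRUCell does not support reset_after parameter set to true.
}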