
Can't save Keras RNN model with custom cell whose call function accepts constants #43369

Closed
@juulie

Description


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.3.0
  • Python version: 3.8.5
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior

When I create an RNN model with a custom cell that accepts constants, saving the model in the default SavedModel format raises `ValueError: RNN cell does not support constants`. Saving in the HDF5 (`.h5`) format works.

Describe the expected behavior

I don't expect the ValueError, since the cell's `call` method does accept a `constants` argument.
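The error text suggests that the RNN layer decides whether a cell supports constants by inspecting the parameters of `cell.call`. Assuming that is the mechanism, the pure-Python sketch below (no TensorFlow required) shows that a method declared like `NestedCell.call` does expose a `constants` parameter. It also shows that a generic `*args, **kwargs` wrapper around the same method hides that parameter from such a check. `has_constants_arg`, `CellStub` and `generic_wrapper` are hypothetical names used only for illustration; they are not Keras internals. Whether SavedModel saving actually wraps `call` this way is an unconfirmed hypothesis.

```python
import inspect


def has_constants_arg(fn):
    """Hypothetical check: does `fn` declare a parameter named `constants`?"""
    return "constants" in inspect.signature(fn).parameters


class CellStub:
    """Hypothetical stand-in with the same call signature as NestedCell.call."""

    def call(self, inputs, states, constants):
        return inputs, states


def generic_wrapper(fn):
    """Hypothetical wrapper that forwards everything through *args/**kwargs.

    It does not use functools.wraps, so inspect.signature only sees
    (*args, **kwargs) and the original parameter names are hidden.
    """
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped


cell = CellStub()
print(has_constants_arg(cell.call))                   # True
print(has_constants_arg(generic_wrapper(cell.call)))  # False
```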

Standalone code to reproduce the issue

I reused the custom cell sample from https://keras.io/guides/working_with_rnns/#rnns-with-listdict-inputs-or-nested-inputs
and only added the `constants` argument:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class NestedCell(keras.layers.Layer):
    def __init__(self, unit_1, unit_2, unit_3, **kwargs):
        self.unit_1 = unit_1
        self.unit_2 = unit_2
        self.unit_3 = unit_3
        self.state_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        self.output_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        super(NestedCell, self).__init__(**kwargs)

    def build(self, input_shapes):
        # expect input_shapes to contain 2 items: [(batch, i1), (batch, i2, i3)]
        i1 = input_shapes[0][1]
        i2 = input_shapes[1][1]
        i3 = input_shapes[1][2]

        self.kernel_1 = self.add_weight(
            shape=(i1, self.unit_1), initializer="uniform", name="kernel_1"
        )
        self.kernel_2_3 = self.add_weight(
            shape=(i2, i3, self.unit_2, self.unit_3),
            initializer="uniform",
            name="kernel_2_3",
        )

    def call(self, inputs, states, constants):
        # inputs should be in [(batch, input_1), (batch, input_2, input_3)]
        # states should be in shape [(batch, unit_1), (batch, unit_2, unit_3)]
        # constants are accepted but not used in this minimal repro
        input_1, input_2 = tf.nest.flatten(inputs)
        s1, s2 = states

        output_1 = tf.matmul(input_1, self.kernel_1)
        output_2_3 = tf.einsum("bij,ijkl->bkl", input_2, self.kernel_2_3)
        state_1 = s1 + output_1
        state_2_3 = s2 + output_2_3

        output = (output_1, output_2_3)
        new_states = (state_1, state_2_3)

        return output, new_states

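    def get_config(self):
        # Merge with the base Layer config so name/dtype/trainable round-trip
        config = super(NestedCell, self).get_config()
        config.update({"unit_1": self.unit_1, "unit_2": self.unit_2, "unit_3": self.unit_3})
        return config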

unit_1 = 10
unit_2 = 20
unit_3 = 30

i1 = 32
i2 = 64
i3 = 32
batch_size = 64
num_batches = 10
timestep = 50


input_1 = keras.Input((None, i1))
input_2 = keras.Input((None, i2, i3))
input_const = keras.Input((None, i1))
cell = NestedCell(unit_1, unit_2, unit_3)
outputs = keras.layers.RNN(cell=cell)(inputs=(input_1, input_2), constants=input_const)

model = keras.models.Model([input_1, input_2, input_const], outputs)

model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
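model.save("test")  # default SavedModel format: raises ValueError: RNN cell does not support constants
# model.save("test.h5")  # HDF5 format: saves without error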
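For context on what the `constants` argument means: the Keras `RNN` documentation describes constants as tensors passed to the cell at each timestep, whereas the inputs are sliced along the time axis. The NumPy sketch below illustrates that contract with a step function that takes `(inputs, states, constants)` in the same order as `NestedCell.call`. `toy_cell` and `run_rnn` are hypothetical helpers; this is not how Keras implements its loop internally.

```python
import numpy as np


def toy_cell(inputs_t, states, constants):
    """Hypothetical step function: adds the current input and a constant bias to the state."""
    (h,) = states
    (bias,) = constants
    new_h = h + inputs_t + bias
    return new_h, (new_h,)


def run_rnn(cell_fn, inputs, initial_states, constants):
    """Hypothetical minimal RNN loop.

    `inputs` has shape (batch, time, features) and is sliced along axis 1,
    while `constants` is passed unchanged to every step.
    """
    states = initial_states
    output = None
    for t in range(inputs.shape[1]):
        output, states = cell_fn(inputs[:, t], states, constants)
    return output, states


inputs = np.ones((2, 3, 4))   # batch=2, time=3, features=4
h0 = np.zeros((2, 4))         # initial state
bias = np.full((2, 4), 0.5)   # one constant tensor, identical at every step
output, final_states = run_rnn(toy_cell, inputs, (h0,), (bias,))
print(output[0])  # [4.5 4.5 4.5 4.5]
```

With all-ones inputs over 3 timesteps and a constant bias of 0.5, each step adds 1.5 to the state, so the final output is 4.5 everywhere.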


Labels: TF 2.11, comp:keras, stale, stat:awaiting response, type:bug
