Closed
Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 10
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.3.0
- Python version: 3.8.5
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
Describe the current behavior
When creating an RNN model with custom cells that accept constants, saving the model with the default SavedModel format raises a ValueError: "RNN cell does not support constants". Saving with the h5 format works.
Describe the expected behavior
I don't expect a ValueError, since the cell's call function does accept the constants argument.
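As a quick sanity check of that claim, here is a minimal sketch that inspects a call signature using only the standard library. The names `accepts_argument`, `CellWithConstants` and `CellWithoutConstants` are hypothetical and made up for illustration. I am not claiming this is the check Keras performs internally; it only confirms whether a `constants` argument is declared on a method.

```python
import inspect


def accepts_argument(fn, name):
    """Return True if `fn` can be called with keyword argument `name`."""
    params = inspect.signature(fn).parameters
    if name in params:
        kind = params[name].kind
        return kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    # A **kwargs catch-all also accepts the name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class CellWithConstants:
    # Hypothetical stand-in with the same call signature as NestedCell.
    def call(self, inputs, states, constants):
        return inputs, states


class CellWithoutConstants:
    # Hypothetical stand-in for a cell that does not declare `constants`.
    def call(self, inputs, states):
        return inputs, states


print(accepts_argument(CellWithConstants().call, "constants"))
print(accepts_argument(CellWithoutConstants().call, "constants"))
```

The same helper can be pointed at `NestedCell(...).call` from the reproduction script to confirm that the argument is declared before the model is saved.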
Standalone code to reproduce the issue
I've re-used the custom cell sample from https://keras.io/guides/working_with_rnns/#rnns-with-listdict-inputs-or-nested-inputs
but just added the constants argument:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class NestedCell(keras.layers.Layer):
    def __init__(self, unit_1, unit_2, unit_3, **kwargs):
        self.unit_1 = unit_1
        self.unit_2 = unit_2
        self.unit_3 = unit_3
        self.state_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        self.output_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        super(NestedCell, self).__init__(**kwargs)

    def build(self, input_shapes):
        # expect input_shapes to contain 2 items, [(batch, i1), (batch, i2, i3)]
        i1 = input_shapes[0][1]
        i2 = input_shapes[1][1]
        i3 = input_shapes[1][2]
        self.kernel_1 = self.add_weight(
            shape=(i1, self.unit_1), initializer="uniform", name="kernel_1"
        )
        self.kernel_2_3 = self.add_weight(
            shape=(i2, i3, self.unit_2, self.unit_3),
            initializer="uniform",
            name="kernel_2_3",
        )

    def call(self, inputs, states, constants):
        # inputs should be in [(batch, input_1), (batch, input_2, input_3)]
        # states should be in shape [(batch, unit_1), (batch, unit_2, unit_3)]
        # constants is accepted but not used in the computation
        input_1, input_2 = tf.nest.flatten(inputs)
        s1, s2 = states
        output_1 = tf.matmul(input_1, self.kernel_1)
        output_2_3 = tf.einsum("bij,ijkl->bkl", input_2, self.kernel_2_3)
        state_1 = s1 + output_1
        state_2_3 = s2 + output_2_3
        output = (output_1, output_2_3)
        new_states = (state_1, state_2_3)
        return output, new_states

    def get_config(self):
        # use self.unit_2 rather than the module-level unit_2
        return {"unit_1": self.unit_1, "unit_2": self.unit_2, "unit_3": self.unit_3}

unit_1 = 10
unit_2 = 20
unit_3 = 30
i1 = 32
i2 = 64
i3 = 32
batch_size = 64
num_batches = 10
timestep = 50
input_1 = keras.Input((None, i1))
input_2 = keras.Input((None, i2, i3))
input_const = keras.Input((None, i1))
cell = NestedCell(unit_1, unit_2, unit_3)
outputs = keras.layers.RNN(cell=cell)(inputs=(input_1, input_2), constants=input_const)
model = keras.models.Model([input_1, input_2, input_const], outputs)
model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
model.save("test")