Named dictionary outputs in tf.keras.Model do not work #34199

Closed
kpe opened this issue Nov 12, 2019 · 7 comments
Labels: comp:keras (Keras related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.0 (Issues relating to TensorFlow 2.0), type:bug (Bug)

Comments

@kpe

kpe commented Nov 12, 2019

System information

  • Have I written custom code: Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): any
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below): 2.0.0
  • Python version: 3.6.5

Describe the current behavior
Using a custom model with named (dictionary) outputs does not work in TF 2.0.

Describe the expected behavior
Using a tuple for a multi-output model works fine, but using a dictionary fails. Dict inputs and outputs appear to be allowed in the code (and in the tf.keras documentation), but the functionality does not seem to work yet.

Code to reproduce the issue
https://colab.research.google.com/gist/kpe/501901b5197675818a2e8a0e0bc8f3a6/keras-named-output-dict.ipynb

%tensorflow_version 2.x
import tensorflow as tf


max_seq_len    = 8
channels_count = 11

class MultiOutputModel(tf.keras.Model):
    def __init__(self):
        super(MultiOutputModel, self).__init__()
        self.dense_a = tf.keras.layers.Dense(3)
        self.dense_b = tf.keras.layers.Dense(4)
        
    def call(self, inputs):
        seq = inputs["F"]
        out_a = self.dense_a(seq)
        out_b = self.dense_b(seq)
        return {"A": out_a, "B": out_b}
    
def ds_gen():
    while True:
        inputs  = {"F": tf.random.uniform((max_seq_len, channels_count))}
        outputs = {"A": tf.random.uniform((), minval=0, maxval=3, dtype=tf.int32), 
                   "B": tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)}
        yield inputs, outputs
        
ds = tf.data.Dataset.from_generator(ds_gen, 
                                    output_types=({"F": tf.float32}, 
                                                  {"A": tf.int32, "B":tf.int32}), 
                                    output_shapes=({"F": tf.TensorShape([max_seq_len, channels_count])}, 
                                                   {"A":tf.TensorShape([]), "B":tf.TensorShape([])}))
# check dataset - a (features, labels) tuple
for inp, out in ds.batch(8).take(1):
    for ndx, (name, val) in enumerate(inp.items()):
        print("features {}: {}: {}".format(ndx, name, val.shape), val.dtype)
    for ndx, (name, val) in enumerate(out.items()):
        print("  labels {}: {}: {}".format(ndx, name, val.shape), val.dtype)
    
model = MultiOutputModel()

def features_only(feat, lab):
    return feat

pred = model.predict(ds.map(features_only).batch(8).take(1))
print(" prediction:", type(pred))
for ndx, out in enumerate(pred):
    print(" pred out {}: {}".format(ndx, out.shape), out.dtype)

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss={"A": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    "B": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),})

model.predict({"seq": tf.constant([[2],[1]])})
model.fit(ds.batch(1))

The output from the above code is:

features 0: F: (8, 8, 11) <dtype: 'float32'>
  labels 0: A: (8,) <dtype: 'int32'>
  labels 1: B: (8,) <dtype: 'int32'>
 prediction: <class 'list'>
 pred out 0: (8, 8, 3) float32
 pred out 1: (8, 8, 4) float32
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-948267c0fc05> in <module>()
     49 model.compile(optimizer=tf.keras.optimizers.Adam(),
     50               loss={"A": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
---> 51                     "B": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),})
     52 
     53 model.fit(ds.batch(8))

3 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/utils/generic_utils.py in check_for_unexpected_keys(name, input_dict, expected_values)
    589     raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
    590                      'following keys: {}'.format(name, list(unknown),
--> 591                                                  expected_values))
    592 
    593 

ValueError: Unknown entries in loss dictionary: ['A', 'B']. Only expected following keys: ['output_1', 'output_2']
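
The ValueError above comes from a key check during compile: the keys of the loss dictionary are compared against the model's output names, which TF 2.0 reports here as output_1 and output_2 instead of the keys of the returned dict. A minimal pure-Python sketch of that check, modelled on the three lines visible in the traceback (the set-difference line and the sorting of the unknown keys are assumptions, not copied from the traceback):

```python
def check_for_unexpected_keys(name, input_dict, expected_values):
    """Raise if input_dict has keys that are not in expected_values."""
    # Assumption: the unknown keys are computed as a set difference.
    unknown = set(input_dict.keys()).difference(expected_values)
    if unknown:
        # sorted() is used here only to make the message deterministic.
        raise ValueError(
            "Unknown entries in {} dictionary: {}. Only expected "
            "following keys: {}".format(name, sorted(unknown), expected_values)
        )


# Output names as reported in the error message of this issue.
output_names = ["output_1", "output_2"]

try:
    # Placeholder strings stand in for the loss objects.
    check_for_unexpected_keys("loss", {"A": "loss_a", "B": "loss_b"}, output_names)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

With keys taken from output_names the check passes silently, which is why the error message points at output_1 and output_2 as the only accepted keys.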
@ravikyram ravikyram self-assigned this Nov 13, 2019
@ravikyram ravikyram added comp:keras Keras related issues TF 2.0 Issues relating to TensorFlow 2.0 labels Nov 13, 2019
@ravikyram
Contributor

@kpe

I tried reproducing the issue in Colab with TF 2.0 and I am seeing the error message below:
AttributeError: 'str' object has no attribute 'dtype'
Is this the expected behavior? Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Nov 13, 2019
@kpe
Author

kpe commented Nov 19, 2019

@ravikyram - I updated the example and added a Colab link for easier reproducibility.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Nov 20, 2019
@ravikyram
Contributor

I have tried on Colab with TF versions 2.0 and 2.1.0-dev20191119 and was able to reproduce the issue. Please find the gist here. Thanks!

@ravikyram ravikyram added the type:bug Bug label Nov 20, 2019
@galfridman

galfridman commented Nov 27, 2019

I would also like to add that in TensorFlow 1.15 and below this issue could be worked around by modifying model.output_names and model._output_layers.

e.g.:

class CustomLayer(tf.keras.layers.Layer):

  def __init__(self, ...):
    super(CustomLayer, self).__init__()
    self.nested_layer_1 = tf.keras.layers.Dense(...)
    self.nested_layer_2 = tf.keras.layers.Dense(...)
    # for each nested layer that produces an output, add it to a dict mapping layer names to layers
    self.nested_layers_dict = {}
    self.nested_layers_dict[self.nested_layer_1.name] = self.nested_layer_1
    self.nested_layers_dict[self.nested_layer_2.name] = self.nested_layer_2
    ...

  def call(self, inputs):
    ...
    # collect the outputs under the names of the layers that produced them
    outputs_dict = {}
    outputs_dict[self.nested_layer_1.name] = self.nested_layer_1(inputs)
    outputs_dict[self.nested_layer_2.name] = self.nested_layer_2(inputs)
    return outputs_dict

def __main__():
  ...
  inputs = [...]
  custom_layer = CustomLayer()
  outputs = custom_layer(inputs)
  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
  # override the generated output names and the (private) list of output layers
  model.output_names = list(custom_layer.output.keys())
  model._output_layers = [custom_layer.nested_layers_dict[output_name] for output_name in model.output_names]

When we do the same on TensorFlow 2.0, changing _output_layers causes the model to add the nested layers to itself, which creates extra layers on top of the original model.

@galfridman

Any news regarding the bug?

@omalleyt12
Contributor

Thanks for the issue! Support for arbitrary nested structures (including dicts) is available in the latest tf-nightly: pip install -U tf-nightly.

I think there are some issues in the code provided regarding the loss used (I don't think the Model is outputting data of the shape that SparseCategoricalCrossentropy expects), but I confirmed that prediction is working as expected.

Also see this bug: #33245

import tensorflow as tf


max_seq_len    = 8
channels_count = 11

class MultiOutputModel(tf.keras.Model):
    def __init__(self):
        super(MultiOutputModel, self).__init__()
        self.dense_a = tf.keras.layers.Dense(3)
        self.dense_b = tf.keras.layers.Dense(4)
        
    def call(self, inputs):
        seq = inputs["F"]
        out_a = self.dense_a(seq)
        out_b = self.dense_b(seq)
        return {"A": out_a, "B": out_b}
    
def ds_gen():
    while True:
        inputs  = {"F": tf.random.uniform((max_seq_len, channels_count))}
        outputs = {"A": tf.random.uniform((), minval=0, maxval=3, dtype=tf.int32), 
                   "B": tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)}
        yield inputs, outputs
        
ds = tf.data.Dataset.from_generator(ds_gen, 
                                    output_types=({"F": tf.float32}, 
                                                  {"A": tf.int32, "B":tf.int32}), 
                                    output_shapes=({"F": tf.TensorShape([max_seq_len, channels_count])}, 
                                                   {"A":tf.TensorShape([]), "B":tf.TensorShape([])}))
# check dataset - a (features, labels) tuple
for inp, out in ds.batch(8).take(1):
    for ndx, (name, val) in enumerate(inp.items()):
        print("features {}: {}: {}".format(ndx, name, val.shape), val.dtype)
    for ndx, (name, val) in enumerate(out.items()):
        print("  labels {}: {}: {}".format(ndx, name, val.shape), val.dtype)
    
model = MultiOutputModel()

def features_only(feat, lab):
    return feat

pred = model.predict(ds.map(features_only).batch(8).take(1))
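
The shape concern raised in this comment can be checked against the shapes printed in the original report: predictions of shape (8, 8, 3) and (8, 8, 4) against labels of shape (8,). A rough sketch, assuming a sparse categorical loss needs the label shape to equal the logits shape without its trailing class axis; the helper name is hypothetical and no TensorFlow call is made:

```python
def sparse_labels_fit_logits(label_shape, logits_shape):
    """Return True if integer class labels line up with the logits.

    Assumption: labels need one entry per logits vector, i.e. the label
    shape equals the logits shape with the trailing class axis dropped.
    """
    return tuple(label_shape) == tuple(logits_shape)[:-1]


# Shapes taken from the output printed in the original report.
label_shape = (8,)
logits_a = (8, 8, 3)  # Dense(3) applied to a (batch, seq, channels) input
logits_b = (8, 8, 4)  # Dense(4) applied to the same input

fits_a = sparse_labels_fit_logits(label_shape, logits_a)
fits_b = sparse_labels_fit_logits(label_shape, logits_b)

# If the sequence axis were reduced before the Dense layers (for example by
# pooling, which the original model does not do), the shapes would line up.
fits_pooled = sparse_labels_fit_logits(label_shape, (8, 3))

print(fits_a, fits_b, fits_pooled)
```

Under that assumption, the model in the report yields one logits vector per sequence position while the dataset yields one label per example, so either the labels or the model output would need to change before fit can work.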
