Fix KeyError when validation_data was given as a dict #30258
This fix tries to address the issue raised in #30122, where a KeyError was thrown during model.fit when validation_data was given as a dict. This fix fixes #30122.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
```diff
@@ -207,7 +207,8 @@ def model_iteration(model,
     val_samples_or_steps = validation_steps
   else:
     # Get num samples for printing.
-    val_samples_or_steps = val_inputs and val_inputs[0].shape[0] or None
+    vals = val_inputs.values() if isinstance(val_inputs, dict) else val_inputs
+    val_samples_or_steps = vals and vals[0].shape[0] or None
```
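A plain-Python reproduction (no TensorFlow needed) of what the removed line and the first revision's added lines do when `val_inputs` is a dict. The two expressions are copied from the diff; `FakeArray` is a hypothetical stand-in for a NumPy array, and the input names are made up for illustration.

```python
class FakeArray:
    """Hypothetical stand-in for a NumPy array; only `.shape` is used."""

    def __init__(self, shape):
        self.shape = shape


def original_expr(val_inputs):
    # The removed line: indexes the container with the integer 0,
    # which for a dict means "look up the key 0" -> KeyError.
    return val_inputs and val_inputs[0].shape[0] or None


def first_revision_expr(val_inputs):
    # The added lines from the first revision of the PR.
    vals = val_inputs.values() if isinstance(val_inputs, dict) else val_inputs
    return vals and vals[0].shape[0] or None


def outcome(fn, arg):
    """Returns fn(arg), or the name of the KeyError/TypeError it raises."""
    try:
        return fn(arg)
    except (KeyError, TypeError) as err:
        return type(err).__name__


dict_inputs = {"input_a": FakeArray((50, 3)), "input_b": FakeArray((50, 1))}
list_inputs = [FakeArray((50, 3))]

print(outcome(original_expr, list_inputs))        # 50
print(outcome(original_expr, dict_inputs))        # KeyError
print(outcome(first_revision_expr, dict_inputs))  # TypeError (Python 3)
```

Under Python 2, `dict.values()` returns a list, so the first revision only breaks on Python 3, where it returns a view that cannot be indexed.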
Probably skip the dict check and use nest.flatten(vals)[0].shape[0] here.
Thanks @qlzh727, the PR has been updated.
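The reviewer's suggestion relies on `nest.flatten`, which turns a nested structure (dict, list, tuple, or a single array) into a flat list of leaves; for dicts, TensorFlow's `nest` orders values by sorted key. Below is a minimal pure-Python sketch of that idea, assuming only dicts, lists, and tuples need handling. `flatten`, `num_samples`, and `FakeArray` are hypothetical helpers for illustration, not TensorFlow's implementation (for example, the real `nest` also treats namedtuples as structures, and `nest.flatten(None)` returns `[None]` rather than short-circuiting).

```python
class FakeArray:
    """Hypothetical stand-in for a NumPy array; only `.shape` is used."""

    def __init__(self, shape):
        self.shape = shape


def flatten(structure):
    """Simplified, hypothetical stand-in for nest.flatten.

    Dict values are visited in sorted-key order, lists and tuples in order,
    and anything else is treated as a single leaf.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def num_samples(val_inputs):
    """Leading dimension of the first input, or None if there is no input."""
    if val_inputs is None:
        return None
    leaves = flatten(val_inputs)
    return leaves[0].shape[0] if leaves else None


print(num_samples({"input_b": FakeArray((50, 1)), "input_a": FakeArray((50, 3))}))  # 50
print(num_samples([FakeArray((20, 4))]))  # 20
print(num_samples(FakeArray((10,))))      # 10
print(num_samples(None))                  # None
```

Because flattening accepts a dict, a list, or a bare array alike, the `isinstance` check is no longer needed. Taking the first leaf is enough for a progress message as long as every input array has the same number of samples, which Keras expects of model inputs.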
```diff
@@ -110,6 +110,43 @@ def test_print_info_with_numpy(self, do_validation):
     if do_validation:
       self.assertIn(", validate on 50 samples", mock_stdout.getvalue())
 
+  def test_dict_input(self):
```
Should be test_dict_validation_input()
Updated.
```diff
+    model = my_model()
+    model.compile(loss="mae", optimizer="adam")
+
+    mock_stdout = six.StringIO()
```
I don't see a reason to mock the output here. Are you trying to validate any output?
Thanks for the comment. Removed.
Thanks @qlzh727 for the review. The PR has been updated. Please take a look and let me know if there are any other issues.
Ping @qlzh727, any chance to take a look at the PR?
PiperOrigin-RevId: 260005639