
Keras fails to account for smaller last batch in loss metric calculation #38596

Closed
bentyeh opened this issue Apr 16, 2020 · 6 comments
Labels: comp:keras · regression issue · stat:awaiting tensorflower · TF 2.2 · type:bug

Comments


bentyeh commented Apr 16, 2020

System information

  • OS Platform and Distribution: Google Colab Notebook
  • TensorFlow version (use command below): 2.2.0-rc2
  • Python version: 3.6

Describe the current behavior

In tf.keras models, the loss aggregation used by model.train_step() and model.test_step() (and hence by model.fit() and model.evaluate()) incorrectly computes the mean loss over all batches in an epoch when the dataset size is not evenly divisible by the batch size. This applies to both the training and the validation loss. The bug affects the reported epoch loss, but NOT the training loss used to compute gradient updates.

Currently, TensorFlow-Keras computes the mean loss for each batch, adds these per-batch means together, then divides by the number of batches. In other words, the loss reported at the end of each epoch is (incorrectly) unweighted with respect to the size of each batch.

For example, suppose there are 3 samples in a dataset, and the batch size is 2. This yields two batches, of sizes 2 and 1. If the first batch has a mean loss of 10 and the second batch has a mean loss of 9, then the mean loss over the entire dataset is currently (incorrectly) computed as (10 + 9) / 2 = 9.5.

Describe the expected behavior

Continuing with the example above, the correct mean loss over the dataset should be a weighted mean of the batch losses, where the weights are the batch sizes. Thus, the correct mean loss is (10*2 + 9*1) / (2 + 1) ≈ 9.6667, as shown in the code below.
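
For concreteness, here is a minimal plain-Python sketch of the two computations (the numbers come from the example above; this illustrates the arithmetic, not TensorFlow internals):

batch_losses = [10.0, 9.0]  # mean loss of each batch
batch_sizes = [2, 1]        # number of samples in each batch

# What Keras currently reports: an unweighted mean over batches.
unweighted = sum(batch_losses) / len(batch_losses)  # 9.5

# What it should report: a mean weighted by batch size.
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)  # 9.666...

print(unweighted, weighted)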

Standalone code to reproduce the issue

Code (gist here)

import tensorflow as tf

X = tf.constant([[1],
                 [2],
                 [3]], dtype=tf.float32)
y = tf.constant([[5],
                 [4],
                 [6]], dtype=tf.float32)

# y_pred = a * x + b, where weights are initialized as a = 1, b = 0
# thus, y_pred = x, and MSE = sum((x - y)**2) / len(x)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_dim=1, kernel_initializer='ones', bias_initializer='zeros')])
model.compile(optimizer='sgd', loss='mean_squared_error')

# Reference implementation: mean squared error over all samples.
def mse(y, y_pred):
    assert len(y) == len(y_pred)
    return sum((y - y_pred)**2) / len(y)

print('model.evaluate():')
print('- batch_size=1:', model.evaluate(X, y, batch_size=1, verbose=0))
print('- batch_size=2:', model.evaluate(X, y, batch_size=2, verbose=0))
print('- batch_size=3:', model.evaluate(X, y, batch_size=3, verbose=0))
print()

# Unweighted mean of two different-sized batches: batch 1 has size 2,
# batch 2 has size 1, so a weighted mean is required, but
# TensorFlow-Keras fails to compute one.
print((mse(X[:-1], y[:-1]) + mse(X[-1], y[-1]))/2)

Output

model.evaluate():
- batch_size=1: 9.666666984558105
- batch_size=2: 9.5
- batch_size=3: 9.666666984558105

tf.Tensor([9.5], shape=(1,), dtype=float32)

Where this error occurs in TensorFlow source code

The following line in the model.test_step() method calls the self.compiled_loss object.

self.compiled_loss(
    y, y_pred, sample_weight, regularization_losses=self.losses)

self.compiled_loss is a compile_utils.LossesContainer object whose __call__() method seems to be implemented incorrectly. Specifically, the following line is where each batch's total loss is accumulated over an epoch, but the accumulation is done without any record of the batch size.

self._loss_metric.update_state(total_loss_metric_value)

Consequently, the mean epoch loss is calculated (via m.result() below)

return {m.name: m.result() for m in self.metrics}

by dividing the total accumulated loss by the number of batches (self.count).

return math_ops.div_no_nan(self.total, self.count)
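
As far as I can tell, self._loss_metric is a tf.keras.metrics.Mean instance, so the unweighted behavior can be reproduced directly. A sketch using the batch losses from the example above:

import tensorflow as tf

# Each update_state() call adds the batch's mean loss to `total` and,
# absent a sample_weight, adds 1 to `count`, so result() = total / count
# is an unweighted mean over batches.
loss_metric = tf.keras.metrics.Mean()
loss_metric.update_state(10.0)  # batch of size 2; the size is not recorded
loss_metric.update_state(9.0)   # batch of size 1; the size is not recorded
print(loss_metric.result().numpy())  # 9.5, the unweighted per-batch mean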

Proposed solution

I don't know what the best way to solve this problem is, but the accumulation of each batch's loss clearly needs to track each batch's actual size. One possible solution is to use the sample_weight argument and replace

self._loss_metric.update_state(total_loss_metric_value)

with

self._loss_metric.update_state(total_loss_metric_value, sample_weight=ACTUAL_BATCH_SIZE)
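
where ACTUAL_BATCH_SIZE is a placeholder for however the batch size would actually be obtained. A quick sanity check with tf.keras.metrics.Mean suggests that passing the batch size as sample_weight yields the desired weighted mean:

import tensorflow as tf

# With sample_weight=n, Mean accumulates total += value * n and
# count += n, so result() becomes the batch-size-weighted mean.
loss_metric = tf.keras.metrics.Mean()
loss_metric.update_state(10.0, sample_weight=2)  # batch of size 2
loss_metric.update_state(9.0, sample_weight=1)   # batch of size 1
print(loss_metric.result().numpy())  # 9.6666..., the weighted mean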

Related Issues

To the best of my knowledge, the problem described above is the root problem for a number of other reported issues: #35585 #35533 #38004 #38165

@chrisyeh96 (Contributor)

Experiencing the same bug here. Can confirm and reproduce using the code shared above.


HuggingLLM commented Apr 16, 2020

I agree the loss should be weighted by batch size, but I don't think the large loss gap between model.fit and a GradientTape training loop is caused by this. If you try my code in #35585 (comment), you'll find that to get matching metrics between a GradientTape loop and model.fit, you need to run the GradientTape loop for more epochs than model.fit; for example, if model.fit runs for 10 epochs, the GradientTape loop needs around 100.

@ravikyram ravikyram added comp:keras Keras related issues TF 2.2 Issues related to TF 2.2 labels Apr 16, 2020
@ravikyram (Contributor)

I have tried this on Colab with TF version 2.2.0-rc2 and was able to reproduce the issue. Please find the gist here. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 16, 2020
@goldiegadde goldiegadde added the regression issue To spot regression issues in latest version label Apr 20, 2020
@goldiegadde goldiegadde added this to To do in TensorFlow 2.2.0 Apr 21, 2020
@goldiegadde goldiegadde moved this from To do to In progress in TensorFlow 2.2.0 Apr 23, 2020
@jvishnuvardhan (Contributor)

@chrisyeh96 This was resolved in a recent tf-nightly and will be available in the stable TF 2.2 release in the near future. Here is the gist for your reference. Thanks!

I am closing this issue as resolved. Please feel free to reopen if the issue persists. Thanks!


TensorFlow 2.2.0 automation moved this from In progress to Done Apr 28, 2020

bentyeh commented Apr 28, 2020

Note to self: Original fix commit to the master branch is 4f17f35
