Keras fails to account for smaller last batch in loss metric calculation #38596
Comments
Experiencing the same bug here. Can confirm and reproduce using the code shared above.
I agree the loss should be weighted by batch size, but I don't think the big gap in loss between model.fit and GradientTape is caused by this. You can try my code in #35585 (comment); I found that to get the same metrics from a GradientTape loop as from model.fit, you have to run the GradientTape loop for many more epochs, e.g. if model.fit uses 10 epochs, the GradientTape loop needs 100.
I have tried on Colab with TF version 2.2.0-rc2 and was able to reproduce the issue. Please find the gist here. Thanks!
@chrisyeh96 This was resolved in a recent tf-nightly and will be available in stable as well. I am closing this issue as resolved. Please feel free to reopen if the issue persists. Thanks!
Note to self: Original fix commit to the master branch is 4f17f35
System information
Describe the current behavior
In `tf.keras` models, the `model.test_step()` method (which is called by `model.fit()` and `model.evaluate()`) incorrectly computes the mean loss over all batches in an epoch when the dataset size is not evenly divisible by the batch size. This applies to both training and validation loss. The bug affects the reported epoch loss, but NOT the training loss used for computing gradient updates.

Currently, TensorFlow-Keras computes the loss for each batch, adds together the losses across batches, then divides by the number of batches. In other words, the reported loss at the end of each epoch is (incorrectly) unweighted with respect to the size of each batch.

For example, suppose there are 3 samples in a dataset and the batch size is 2. Then there are 2 batches, of sizes 2 and 1. If the first batch has a mean loss of 10 and the second batch has a mean loss of 9, then the mean loss over the entire dataset is currently (incorrectly) computed as `(10 + 9) / 2 = 9.5`.
Describe the expected behavior

Continuing with the example above, the correct mean loss over the dataset should be a weighted mean of the batch losses, where the weights are given by each batch's size. Thus, the correct mean loss should be `(10*2 + 9*1) / (2 + 1) ≈ 9.667`. This is shown in the code below.

Standalone code to reproduce the issue
Code (gist here)
Output
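The reproduction code and its output are in the linked gist rather than inlined above. As an illustration only, a minimal sketch of the kind of comparison involved might look like the following; the dataset of 3 samples, batch size of 2, and MSE loss are assumptions, and the exact numbers depend on the random initialization:

```python
import numpy as np
import tensorflow as tf

# 3 samples with batch_size=2 -> one batch of 2 samples and one batch of 1 sample.
x = np.random.rand(3, 4).astype(np.float32)
y = np.random.rand(3, 1).astype(np.float32)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

# Loss reported by Keras: on affected versions, an unweighted mean of the
# per-batch mean losses.
keras_loss = model.evaluate(x, y, batch_size=2, verbose=0)

# Loss computed directly over the whole dataset (each sample weighted equally).
true_loss = tf.keras.losses.MeanSquaredError()(
    y, model.predict(x, batch_size=2)).numpy()

print(keras_loss, true_loss)  # on affected versions these generally disagree
```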
Where this error occurs in TensorFlow source code
The following line in the `model.test_step()` method calls the `self.compiled_loss` object:

tensorflow/tensorflow/python/keras/engine/training.py, lines 971 to 972 in 42052dc

`self.compiled_loss` is a `compile_utils.LossesContainer` object whose `__call__()` method seems to be implemented incorrectly. Specifically, the following line is where each batch's total loss is accumulated over an epoch, but the accumulation is done without any record of the batch size:

tensorflow/tensorflow/python/keras/engine/compile_utils.py, line 235 in 42052dc

Consequently, the mean epoch loss is calculated (`m.result()` below)

tensorflow/tensorflow/python/keras/engine/training.py, line 975 in 42052dc

by dividing the total accumulated loss by the number of batches (`self.count`):

tensorflow/tensorflow/python/keras/metrics.py, line 383 in a5a8cee
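The effect of this unweighted accumulation can be seen directly with the public `tf.keras.metrics.Mean` API; this is a standalone sketch of the metric's behavior, not the actual `LossesContainer` code:

```python
import tensorflow as tf

# Each batch's mean loss is fed in with no sample weight, so result() is
# total / count, where count is the number of update_state() calls,
# i.e. the number of batches.
m = tf.keras.metrics.Mean()
m.update_state(10.0)  # mean loss of the batch with 2 samples
m.update_state(9.0)   # mean loss of the batch with 1 sample
print(m.result().numpy())  # 9.5
```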
Proposed solution
I don't know what the best way to solve this problem may be, but the accumulation of each batch's loss should clearly track each batch's actual size. One possible solution may be to use the `sample_weight` argument and replace

tensorflow/tensorflow/python/keras/engine/compile_utils.py, line 235 in 42052dc

with

`self._loss_metric.update_state(total_loss_metric_value, sample_weight=ACTUAL_BATCH_SIZE)`
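Here `ACTUAL_BATCH_SIZE` is only a placeholder name, not an existing variable; how the batch size would be obtained inside `LossesContainer` is left open. As an illustration of why weighting fixes the average, the public `Mean` metric behaves as follows (a sketch continuing the example above):

```python
import tensorflow as tf

# Weighting each batch's mean loss by its batch size makes result() return
# the per-sample mean over the whole epoch.
m = tf.keras.metrics.Mean()
m.update_state(10.0, sample_weight=2)  # batch of 2 samples
m.update_state(9.0, sample_weight=1)   # batch of 1 sample
print(m.result().numpy())  # (10*2 + 9*1) / (2 + 1) ≈ 9.667
```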
Related Issues
To the best of my knowledge, the problem described above is the root cause of a number of other reported issues: #35585, #35533, #38004, #38165