Fix last partial batch loss regression in 2.2
PiperOrigin-RevId: 307666011
Change-Id: I4ede295280b78e18b5b8b52f0c211d5c0a7913e2
pavithrasv authored and tensorflower-gardener committed Apr 21, 2020
1 parent 81734a8 commit 4f17f35
Showing 2 changed files with 119 additions and 101 deletions.
9 changes: 7 additions & 2 deletions tensorflow/python/keras/engine/compile_utils.py
@@ -192,6 +192,7 @@ def __call__(self,
 
     loss_values = []  # Used for gradient calculation.
     loss_metric_values = []  # Used for loss metric calculation.
+    batch_dim = None
     zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                 self._per_output_metrics)
     for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
@@ -207,8 +208,11 @@ def __call__(self,
       # Correct for the `Mean` loss metrics counting each replica as a batch.
       if loss_obj.reduction == losses_utils.ReductionV2.SUM:
         loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
+
+      if batch_dim is None:
+        batch_dim = array_ops.shape(y_t)[0]
       if metric_obj is not None:
-        metric_obj.update_state(loss_metric_value)
+        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)
 
       if loss_weight is not None:
         loss_value *= loss_weight
@@ -232,7 +236,8 @@ def __call__(self,
       loss_metric_values = losses_utils.cast_losses_to_common_dtype(
           loss_metric_values)
       total_loss_metric_value = math_ops.add_n(loss_metric_values)
-      self._loss_metric.update_state(total_loss_metric_value)
+      self._loss_metric.update_state(
+          total_loss_metric_value, sample_weight=batch_dim)
 
       loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
       total_loss = math_ops.add_n(loss_values)

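The change above weights every loss-metric update by the batch dimension, so the final, smaller batch no longer skews the reported mean loss. Below is a minimal sketch of the effect using a standalone `tf.keras.metrics.Mean`; it is not part of the commit, and the loss values, batch sizes, and variable names are hypothetical. Only `Mean.update_state(..., sample_weight=...)` comes from the change itself.

    import tensorflow as tf

    # Hypothetical per-batch mean losses: three full batches of 32 samples
    # followed by a partial final batch of 8 samples.
    batch_losses = [0.5, 0.4, 0.3, 1.0]
    batch_sizes = [32, 32, 32, 8]

    unweighted = tf.keras.metrics.Mean()  # pre-fix: every batch counts equally
    weighted = tf.keras.metrics.Mean()    # post-fix: weighted by the batch dim

    for loss, size in zip(batch_losses, batch_sizes):
      unweighted.update_state(loss)
      weighted.update_state(loss, sample_weight=size)

    # True per-sample mean for comparison.
    true_mean = (sum(l * s for l, s in zip(batch_losses, batch_sizes))
                 / sum(batch_sizes))
    print(unweighted.result().numpy())  # 0.55   -- skewed by the partial batch
    print(weighted.result().numpy())    # ~0.446 -- matches true_mean
    print(true_mean)                    # 0.4461...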