Skip to content

Commit

Permalink
Remove the sample_weights placeholder for models that don't have sa…
Browse files Browse the repository at this point in the history
…mple weights.

PiperOrigin-RevId: 245053516
  • Loading branch information
pavithrasv authored and tensorflower-gardener committed Apr 24, 2019
1 parent 92f8457 commit 41abdd1
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 81 deletions.
146 changes: 116 additions & 30 deletions tensorflow/python/keras/engine/training.py
Expand Up @@ -321,8 +321,14 @@ def compile(self,
# TODO(reedwm): Support this.
raise ValueError('We currently do not support enabling `run_eagerly` '
'with a LossScaleOptimizer.')

# Prepare sample weight modes. List with the same length as model outputs.
self._sample_weight_modes = training_utils.prepare_sample_weight_modes(
self.output_names, sample_weight_mode,
self._skip_target_weighing_indices)

# Prepare sample weights.
self._set_sample_weight_attributes(sample_weight_mode)
self._prepare_sample_weights()
# Save all metric attributes per output of the model.
self._cache_output_metric_attributes(metrics, weighted_metrics)

Expand Down Expand Up @@ -412,6 +418,11 @@ def compile(self,
targets=self.targets,
skip_target_indices=skip_target_indices)

# Prepare sample weight modes. List with the same length as model outputs.
self._sample_weight_modes = training_utils.prepare_sample_weight_modes(
self.output_names, sample_weight_mode,
self._skip_target_weighing_indices)

# Creates the model loss and weighted metrics sub-graphs.
self._compile_weights_loss_and_weighted_metrics()

Expand Down Expand Up @@ -804,8 +815,13 @@ def _worker_fn(_):
split_at = int(len(x[0]) * (1. - validation_split))
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
sample_weights, val_sample_weights = (slice_arrays(
sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
if sample_weights:
sample_weights, val_sample_weights = (
slice_arrays(sample_weights, 0, split_at),
slice_arrays(sample_weights, split_at),
)
else:
val_sample_weights = None
else:
if validation_steps:
raise ValueError('`validation_steps` should not be specified if '
Expand Down Expand Up @@ -1227,6 +1243,7 @@ class during training. This can be useful to tell the model to "pay
if not isinstance(K.symbolic_learning_phase(), int):
ins += [True] # Add learning phase value.

self._update_sample_weight_modes(sample_weights=sample_weights)
self._make_train_function()
outputs = self.train_function(ins) # pylint: disable=not-callable

Expand Down Expand Up @@ -1294,6 +1311,7 @@ def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
x = training_utils.ModelInputs(x).as_list()
inputs = x + (y or []) + (sample_weights or [])

self._update_sample_weight_modes(sample_weights=sample_weights)
self._make_test_function()
outputs = self.test_function(inputs) # pylint: disable=not-callable

Expand Down Expand Up @@ -1598,13 +1616,61 @@ def predict_generator(self,
verbose=verbose,
callbacks=callbacks)

def _update_sample_weight_modes(self, sample_weights=None):
"""Updates sample weight modes based on training/eval inputs.
If model contains `_sample_weight_modes` we check if the input
`sample_weights` corresponds to the sample weight modes.
1. If sample weight mode for output i is 'temporal', we do not
change it as the `temporal` mode has been set by the user.
2. Set sample weight mode to be 'samplewise' for output i if sample
weight mode was not set before and sample weight inputs are given.
3. Reset sample weight mode to None for output i if sample weight mode
was set to 'samplewise' but there is no sample weight input.
Args:
sample_weights: List of sample weights of the same length as model outputs
or None.
"""
if not getattr(self, '_sample_weight_modes', []):
return
for i in range(len(self._sample_weight_modes)):
sample_weight = sample_weights[i] if sample_weights else None
if self._sample_weight_modes[i] == 'temporal':
# If sample weight mode for output i is 'temporal', do nothing.
continue
if self._sample_weight_modes[i] is None and sample_weight is not None:
# Set sample weight mode to be 'samplewise' for output i if sample
# weight mode was not set before and sample weight inputs are given.
self._sample_weight_modes[i] = 'samplewise'
elif (self._sample_weight_modes[i] == 'samplewise' and
sample_weight is None):
# Reset sample weight mode to None for output i if sample weight mode
# was set to 'samplewise' but there is no sample weight input.
self._sample_weight_modes[i] = None

def _recompile_weights_loss_and_weighted_metrics(self):
recompile = False
for i, mode in enumerate(self._sample_weight_modes):
if ((mode is not None and self.sample_weights[i] is None) or
(mode is None and self.sample_weights[i] is not None)):
# If there is a mismatch between sample weight mode and the placeholders
# created, then recompile the sub-graphs that depend on sample weights.
recompile = True
break

if recompile:
self._compile_weights_loss_and_weighted_metrics()
return recompile

@trackable.no_automatic_dependency_tracking
def _compile_weights_loss_and_weighted_metrics(self):
"""Compiles the model loss and weighted metric sub-graphs."""

with K.get_graph().as_default():

# Prepare sample weights.
self._set_sample_weight_attributes(self.sample_weight_mode)
self._prepare_sample_weights()

masks = self._prepare_output_masks()
skip_target_indices = self._prepare_skip_target_indices()
Expand All @@ -1615,7 +1681,8 @@ def _compile_weights_loss_and_weighted_metrics(self):
masks=masks,
targets=self.targets,
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
sample_weights=self.sample_weights,
return_weighted_metrics=True)

# Compute total loss.
# Used to keep track of the total loss value (stateless).
Expand Down Expand Up @@ -1870,22 +1937,19 @@ def _list_functions_for_serialization(self):
saving_utils.trace_model_call(self))
return all_functions

def _set_sample_weight_attributes(self, sample_weight_mode):
"""Sets sample weight related attributes on the model."""
sample_weights, sample_weight_modes = training_utils.prepare_sample_weights(
self.output_names, sample_weight_mode,
self._skip_target_weighing_indices)
self.sample_weights = sample_weights
self.sample_weight_modes = sample_weight_modes
self._feed_sample_weight_modes = [
sample_weight_modes[i]
for i in range(len(self.outputs))
if i not in self._skip_target_weighing_indices
]
def _prepare_sample_weights(self):
"""Sets sample weight attribute on the model."""
# List with the same length as model outputs.
self.sample_weights = []
for i, name in enumerate(self.output_names):
self.sample_weights.append(
training_utils.get_output_sample_weight(
self._skip_target_weighing_indices, self._sample_weight_modes[i],
name, i))

# Filtering just the placeholders from the above list.
self._feed_sample_weights = [
sample_weights[i]
for i in range(len(sample_weights))
if i not in self._skip_target_weighing_indices
s for s in self.sample_weights if s is not None
]

def _cache_output_metric_attributes(self, metrics, weighted_metrics):
Expand Down Expand Up @@ -2065,6 +2129,7 @@ def _handle_metrics(self,
targets=None,
sample_weights=None,
masks=None,
return_weighted_metrics=False,
return_weighted_and_unweighted_metrics=False):
"""Handles calling metric functions.
Expand All @@ -2074,10 +2139,13 @@ def _handle_metrics(self,
targets: List of targets.
sample_weights: Optional list of sample weight arrays.
masks: List of computed output mask values.
return_weighted_metrics: Flag that indicates whether weighted metrics
should be computed instead of unweighted metrics. This flag is ignored
when `return_weighted_and_unweighted_metrics` is enabled.
return_weighted_and_unweighted_metrics: Flag that is used to indicate
whether both weighted and unweighted metrics should be computed. When
this is not enabled, we use `sample_weights` param to indicate whether
weighted or unweighted metrics should be returned.
this is not enabled, we use `return_weighted_metrics` param to
indicate whether weighted or unweighted metrics should be returned.
Returns:
A list of metric result tensors.
Expand All @@ -2093,18 +2161,19 @@ def _handle_metrics(self,
target = targets[i] if targets else None
output_mask = masks[i] if masks else None

if return_weighted_and_unweighted_metrics or sample_weights is None:
if (return_weighted_and_unweighted_metrics or
not return_weighted_metrics):
metric_results.extend(
self._handle_per_output_metrics(self._per_output_metrics[i],
target, output, output_mask))
if return_weighted_and_unweighted_metrics or sample_weights is not None:
if return_weighted_and_unweighted_metrics or return_weighted_metrics:
metric_results.extend(
self._handle_per_output_metrics(
self._per_output_weighted_metrics[i],
target,
output,
output_mask,
weights=sample_weights[i]))
weights=sample_weights[i] if sample_weights else None))
return metric_results

def _check_trainable_weights_consistency(self):
Expand All @@ -2126,13 +2195,17 @@ def _check_trainable_weights_consistency(self):
' without calling `model.compile` after ?', 1)

def _make_train_function(self):
has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
metrics_tensors = [
self._all_metrics_tensors[m] for m in self.metrics_names[1:]
]
if not self._is_compiled:
raise RuntimeError('You must compile your model before using it.')
self._check_trainable_weights_consistency()
if getattr(self, 'train_function') is None:
# If we have re-compiled the loss/weighted metric sub-graphs then create
# train function even if one exists already. This is because
# `_feed_sample_weights` list has been updated on re-copmpile.
if getattr(self, 'train_function') is None or has_recompiled:
inputs = (self._feed_inputs +
self._feed_targets +
self._feed_sample_weights)
Expand Down Expand Up @@ -2160,12 +2233,16 @@ def _make_train_function(self):
setattr(self, 'train_function', fn)

def _make_test_function(self):
has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
metrics_tensors = [
self._all_metrics_tensors[m] for m in self.metrics_names[1:]
]
if not self._is_compiled:
raise RuntimeError('You must compile your model before using it.')
if getattr(self, 'test_function') is None:
# If we have re-compiled the loss/weighted metric sub-graphs then create
# test function even if one exists already. This is because
# `_feed_sample_weights` list has been updated on re-copmpile.
if getattr(self, 'test_function') is None or has_recompiled:
inputs = (self._feed_inputs +
self._feed_targets +
self._feed_sample_weights)
Expand Down Expand Up @@ -2536,7 +2613,7 @@ def _standardize_user_data(self,
# mixed symbolic/value inputs.
if (not self.run_eagerly and is_build_called and is_compile_called and
not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)):
return [], [], []
return [], [], None

# What follows is input validation and standardization to list format,
# in the case where all inputs are value arrays.
Expand Down Expand Up @@ -2575,7 +2652,7 @@ def _standardize_user_data(self,
feed_sample_weight_modes = [None for _ in self.outputs]
else:
feed_output_names = self._feed_output_names
feed_sample_weight_modes = self._feed_sample_weight_modes
feed_sample_weight_modes = self._sample_weight_modes
feed_output_shapes = []
for output_shape, loss_fn in zip(self._feed_output_shapes,
self._feed_loss_fns):
Expand Down Expand Up @@ -2626,9 +2703,18 @@ def _standardize_user_data(self,
# Additional checks to avoid users mistakenly using improper loss fns.
training_utils.check_loss_and_target_compatibility(
y, self._feed_loss_fns, feed_output_shapes)

# If sample weight mode has not been set and weights are None for all the
# model outputs, return None (we do not create placeholders for
# sample weights) so we do not want to feed any value.
is_sample_weight_mode_set = any(
s is not None for s in feed_sample_weight_modes)
if (not is_sample_weight_mode_set and
all(s is None for s in sample_weights)):
sample_weights = None # If the list contains only None, return None
else:
y = []
sample_weights = []
sample_weights = None

if self.stateful and batch_size:
# Check that for stateful networks, number of samples is a multiple
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/keras/engine/training_arrays.py
Expand Up @@ -147,6 +147,8 @@ def model_iteration(model,
learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
scope.__enter__()

model._update_sample_weight_modes(sample_weights=sample_weights)

# Get step function and loop type.
f = _make_execution_function(model, mode)
use_steps = is_dataset or steps_per_epoch is not None
Expand Down
13 changes: 8 additions & 5 deletions tensorflow/python/keras/engine/training_test.py
Expand Up @@ -1291,11 +1291,14 @@ def test_sample_weights(self):
x_train[:batch_size],
y_train[:batch_size],
sample_weight=sample_weight[:batch_size])
ref_score = model.evaluate(x_test, y_test, verbose=0)
if not context.executing_eagerly():
score = model.evaluate(
x_test[test_ids, :], y_test[test_ids, :], verbose=0)
self.assertLess(score[0], ref_score[0])
ref_score = model.evaluate(
x_test, y_test, verbose=0, sample_weight=sample_weight)
score = model.evaluate(
x_test[test_ids, :],
y_test[test_ids, :],
verbose=0,
sample_weight=sample_weight[test_ids])
self.assertLess(score[0], ref_score[0])

@keras_parameterized.run_all_keras_modes
def test_temporal_sample_weights(self):
Expand Down

0 comments on commit 41abdd1

Please sign in to comment.