Adding AdaNet replay: the ability to rerun training without having to determine the best candidate for each iteration. A list of best indices from the previous run is provided and honored by AdaNet.

PiperOrigin-RevId: 268683890
hanna-maz authored and cweill committed Sep 12, 2019
1 parent 77199f9 commit b5007ae
Showing 19 changed files with 343 additions and 14 deletions.
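Before the file-by-file diff, a minimal usage sketch of the replay flow this commit introduces. The `best_ensemble_indices` constructor argument, the regression head, `my_generator`, and the input functions below are illustrative assumptions (only `adanet.replay.Config` and its `get_best_ensemble_index` method are visible in the changes); assumes TF 1.x.

import adanet
import tensorflow as tf

# First run: a regular AdaNet search. The index of the winning candidate
# is recorded per iteration in the serialized _Architecture.
estimator = adanet.Estimator(
    head=tf.contrib.estimator.regression_head(),  # hypothetical head
    subnetwork_generator=my_generator,            # hypothetical generator
    max_iteration_steps=1000)
estimator.train(input_fn=train_input_fn, max_steps=3000)

# Replay run: identical hyperparameters and search space, possibly new
# training data. The replay config short-circuits candidate evaluation
# and reuses the recorded best index for each iteration.
replay_estimator = adanet.Estimator(
    head=tf.contrib.estimator.regression_head(),
    subnetwork_generator=my_generator,
    max_iteration_steps=1000,
    replay_config=adanet.replay.Config(best_ensemble_indices=[0, 2, 1]))
replay_estimator.train(input_fn=new_train_input_fn, max_steps=3000)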
1 change: 1 addition & 0 deletions RELEASE.md
@@ -15,6 +15,7 @@ limitations under the License.

# Current version (0.8.0.dev)
* Under development.
* Added AdaNet replay: the ability to rerun training without having to determine the best candidate for each iteration. A list of best indices from the previous run is provided and honored by AdaNet.
* TODO: Add official Keras Model support, including Keras layers, Sequential, and Model subclasses for defining subnetworks.
* Introduced `adanet.ensemble.MeanEnsembler` with a basic implementation for taking the mean of logits of subnetworks. This also supports including the mean of last_layer (helpful if subnetworks have same configurations) in the `predictions` and `export_outputs` of the EstimatorSpec.
* **BREAKING CHANGE**: AdaNet now supports arbitrary metrics when choosing the best ensemble. To achieve this, the interface of `adanet.Evaluator` is changing. The `Evaluator.evaluate_adanet_losses(sess, adanet_losses)` function is being replaced with `Evaluator.evaluate(sess, ensemble_metrics)`. The `ensemble_metrics` parameter contains all computed metrics for each candidate ensemble as well as the `adanet_loss`. Code which overrides `evaluate_adanet_losses` must migrate over to use the new `evaluate` method (we suspect that such cases are very rare).
4 changes: 4 additions & 0 deletions adanet/BUILD
@@ -17,6 +17,10 @@ py_library(
deps = [
"//adanet/autoensemble",
"//adanet/core",
"//adanet/distributed",
"//adanet/ensemble",
"//adanet/replay",
"//adanet/subnetwork",
],
)

2 changes: 2 additions & 0 deletions adanet/__init__.py
@@ -20,6 +20,7 @@

from adanet import distributed
from adanet import ensemble
from adanet import replay
from adanet import subnetwork
from adanet.autoensemble import AutoEnsembleEstimator
from adanet.autoensemble import AutoEnsembleSubestimator
@@ -43,6 +44,7 @@
"Ensemble",
"Estimator",
"Evaluator",
"replay",
"ReportMaterializer",
"subnetwork",
"Summary",
1 change: 1 addition & 0 deletions adanet/adanet_test.py
@@ -49,6 +49,7 @@ def test_public(self):
self.assertIsNotNone(adanet.Estimator)
self.assertIsNotNone(adanet.Evaluator)
self.assertIsNotNone(adanet.MixtureWeightType)
self.assertIsNotNone(adanet.replay.Config)
self.assertIsNotNone(adanet.ReportMaterializer)
self.assertIsNotNone(adanet.Subnetwork)
self.assertIsNotNone(adanet.subnetwork.Builder)
9 changes: 9 additions & 0 deletions adanet/autoensemble/estimator.py
@@ -358,6 +358,9 @@ def input_fn_predict:
debug: See :class:`adanet.Estimator`.
enable_ensemble_summaries: See :class:`adanet.Estimator`.
enable_subnetwork_summaries: See :class:`adanet.Estimator`.
global_step_combiner_fn: See :class:`adanet.Estimator`.
max_iterations: See :class:`adanet.Estimator`.
replay_config: See :class:`adanet.Estimator`.
**kwargs: Extra keyword args passed to the parent.
Returns:
@@ -387,6 +390,9 @@ def __init__(self,
debug=False,
enable_ensemble_summaries=True,
enable_subnetwork_summaries=True,
global_step_combiner_fn=tf.math.reduce_mean,
max_iterations=None,
replay_config=None,
**kwargs):
subnetwork_generator = _GeneratorFromCandidatePool(candidate_pool,
logits_fn, last_layer_fn)
@@ -406,4 +412,7 @@ def __init__(self,
debug=debug,
enable_ensemble_summaries=enable_ensemble_summaries,
enable_subnetwork_summaries=enable_subnetwork_summaries,
global_step_combiner_fn=global_step_combiner_fn,
max_iterations=max_iterations,
replay_config=replay_config,
**kwargs)
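Because `AutoEnsembleEstimator` now forwards `global_step_combiner_fn`, `max_iterations`, and `replay_config` to `adanet.Estimator`, replay should work identically there. A hedged sketch; the head, feature columns, and candidate pool are illustrative assumptions, not part of this diff:

import adanet
import tensorflow as tf

head = tf.contrib.estimator.regression_head()  # hypothetical head
feature_columns = [tf.feature_column.numeric_column("x")]  # illustrative

estimator = adanet.AutoEnsembleEstimator(
    head=head,
    candidate_pool=[
        tf.estimator.LinearEstimator(
            head=head, feature_columns=feature_columns),
        tf.estimator.DNNEstimator(
            head=head, feature_columns=feature_columns, hidden_units=[64]),
    ],
    max_iteration_steps=500,
    # Replays the recorded winners instead of re-evaluating candidates.
    replay_config=adanet.replay.Config(best_ensemble_indices=[1, 0]))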
3 changes: 3 additions & 0 deletions adanet/core/BUILD
@@ -52,6 +52,7 @@ py_test(
":evaluator",
":report_materializer",
":testing_utils",
"//adanet/replay",
"//adanet/subnetwork",
"@absl_py//absl/testing:parameterized",
],
@@ -148,6 +149,7 @@ py_test(
name = "iteration_test",
srcs = ["iteration_test.py"],
deps = [
":architecture",
":candidate",
":ensemble_builder",
":iteration",
@@ -194,6 +196,7 @@ py_test(
srcs = ["ensemble_builder_test.py"],
shard_count = 10,
deps = [
":architecture",
":ensemble_builder",
":summary",
":testing_utils",
30 changes: 28 additions & 2 deletions adanet/core/architecture.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

import copy
import json


@@ -28,15 +29,20 @@ class _Architecture(object):
`adanet.subnetwork.Builder` instances that compose the ensemble, the
`adanet.ensemble.Ensembler` that constructed it, as well as the sequence
of states in the search space that led to the construction of this model.
In addition, it stores `replay_indices`: a list of indices (one per
boosting iteration) holding the index of the ensemble in the candidate
list throughout the run.
It is serializable and deserializable for persistent storage.
"""

def __init__(self, ensemble_candidate_name, ensembler_name, global_step=None):
def __init__(self, ensemble_candidate_name, ensembler_name, global_step=None,
replay_indices=None):
self._ensemble_candidate_name = ensemble_candidate_name
self._ensembler_name = ensembler_name
self._global_step = global_step
self._subnets = []
self._replay_indices = replay_indices or []

@property
def ensemble_candidate_name(self):
@@ -76,6 +82,17 @@ def subnetworks(self):

return tuple(self._subnets)

@property
def replay_indices(self):
"""The list of replay indices.
Returns:
A list of integers (one per boosting iteration) holding the index of
the ensemble in the candidate list throughout the run.
"""

return self._replay_indices

@property
def subnetworks_grouped_by_iteration(self):
"""The component subnetworks grouped by iteration number.
@@ -105,6 +122,13 @@ def add_subnetwork(self, iteration_number, builder_name):
"""
self._subnets.append((iteration_number, builder_name))

# TODO: Remove setters and getters.
def add_replay_index(self, index):
self._replay_indices.append(index)

def set_replay_indices(self, indices):
self._replay_indices = copy.copy(indices)

def serialize(self, global_step):
"""Returns a string serialization of this object."""

@@ -115,6 +139,7 @@ def serialize(self, global_step):
"global_step": global_step,
"ensembler_name": self.ensembler_name,
"subnetworks": [],
"replay_indices": self._replay_indices
}
for iteration_number, builder_name in self._subnets:
subnetwork_arch = {
@@ -139,7 +164,8 @@ def deserialize(serialized_architecture):
ensemble_arch = json.loads(serialized_architecture)
architecture = _Architecture(ensemble_arch["ensemble_candidate_name"],
ensemble_arch["ensembler_name"],
ensemble_arch["global_step"])
ensemble_arch["global_step"],
ensemble_arch["replay_indices"])
for subnet in ensemble_arch["subnetworks"]:
architecture.add_subnetwork(subnet["iteration_number"],
subnet["builder_name"])
17 changes: 12 additions & 5 deletions adanet/core/architecture_test.py
@@ -72,8 +72,15 @@ def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
arch.add_subnetwork(*subnetwork)
self.assertEqual(want, arch.subnetworks_grouped_by_iteration)

def test_serialization_lifecycle(self):
def test_set_and_add_replay_index(self):
arch = _Architecture("foo", "dummy_ensembler_name")
arch.set_replay_indices([1, 2, 3])
self.assertAllEqual([1, 2, 3], arch.replay_indices)
arch.add_replay_index(4)
self.assertAllEqual([1, 2, 3, 4], arch.replay_indices)

def test_serialization_lifecycle(self):
arch = _Architecture("foo", "dummy_ensembler_name", replay_indices=[1, 2])
arch.add_subnetwork(0, "linear")
arch.add_subnetwork(0, "dnn")
arch.add_subnetwork(1, "dnn")
@@ -85,10 +92,10 @@ def test_serialization_lifecycle(self):
serialized = arch.serialize(global_step)
self.assertEqual(
'{"ensemble_candidate_name": "foo", "ensembler_name": '
'"dummy_ensembler_name", "global_step": 100, "subnetworks": '
'[{"builder_name": "linear", "iteration_number": 0}, {"builder_name": '
'"dnn", "iteration_number": 0}, {"builder_name": "dnn", '
'"iteration_number": 1}]}', serialized)
'"dummy_ensembler_name", "global_step": 100, "replay_indices": [1, 2], '
'"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
'{"builder_name": "dnn", "iteration_number": 0},'
' {"builder_name": "dnn", "iteration_number": 1}]}', serialized)
deserialized_arch = _Architecture.deserialize(serialized)
self.assertEqual(arch.ensemble_candidate_name,
deserialized_arch.ensemble_candidate_name)
14 changes: 13 additions & 1 deletion adanet/core/ensemble_builder.py
@@ -21,6 +21,7 @@

import collections
import contextlib
import copy
import functools
import inspect

@@ -270,6 +271,7 @@ def build_ensemble_spec(self,
mode,
iteration_number,
labels=None,
my_ensemble_index=None,
previous_ensemble_spec=None):
"""Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.
@@ -286,6 +288,8 @@
iteration_number: Integer current iteration number.
labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
(for multi-head).
my_ensemble_index: An integer holding the index of the ensemble in
AdaNet's candidate list.
previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
iteration t-1. Used for creating the subnetwork train_op.
@@ -304,7 +308,15 @@
step_tensor = tf.convert_to_tensor(value=step)
with summary.current_scope():
summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
architecture = _Architecture(candidate.name, ensembler.name)
replay_indices = []
if previous_ensemble_spec:
replay_indices = copy.copy(
previous_ensemble_spec.architecture.replay_indices)
if my_ensemble_index is not None:
replay_indices.append(my_ensemble_index)

architecture = _Architecture(candidate.name, ensembler.name,
replay_indices=replay_indices)
previous_subnetworks = []
subnetwork_builders = []
previous_ensemble = None
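To make the new bookkeeping in `build_ensemble_spec` easier to follow, here is the accumulation step in isolation: a condensed restatement of the diff above, not additional API.

import copy

def _accumulate_replay_indices(previous_architecture, my_ensemble_index):
  """Extends the previous iteration's replay trail with this candidate."""
  replay_indices = []
  if previous_architecture is not None:
    # Copy so the previous iteration's _Architecture stays untouched.
    replay_indices = copy.copy(previous_architecture.replay_indices)
  if my_ensemble_index is not None:
    replay_indices.append(my_ensemble_index)
  return replay_indices

After iteration t, the winning ensemble's architecture therefore carries one index per iteration 0..t, which is exactly the list that `adanet.replay.Config` replays later.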
16 changes: 16 additions & 0 deletions adanet/core/ensemble_builder_test.py
@@ -313,6 +313,15 @@ class EnsembleBuilderTest(tu.AdanetTestCase):
"want_ensemble_trainable_vars": 2,
"want_subnetwork_trainable_vars": 4,
"export_subnetworks": True,
}, {
"testcase_name": "replay_no_prev",
"adanet_beta": .1,
"want_logits": [[.006], [.082]],
"want_loss": 1.349,
"want_adanet_loss": 1.360,
"want_ensemble_trainable_vars": 1,
"my_ensemble_index": 2,
"want_replay_indices": [2],
})
@test_util.run_in_graph_and_eager_modes
def test_build_ensemble_spec(
@@ -334,6 +343,8 @@ def test_build_ensemble_spec(
multi_head=False,
want_subnetwork_trainable_vars=2,
ensembler_class=ComplexityRegularizedEnsembler,
my_ensemble_index=None,
want_replay_indices=None,
want_predictions=None,
export_subnetworks=False):
seed = 64
@@ -453,8 +464,13 @@ def _mixture_weights_train_op_fn(loss, var_list):
features=features,
iteration_number=1,
labels=labels,
my_ensemble_index=my_ensemble_index,
mode=mode)

if want_replay_indices:
self.assertAllEqual(want_replay_indices,
ensemble_spec.architecture.replay_indices)

with tf_compat.v1.Session(graph=g).as_default() as sess:
sess.run(tf_compat.v1.global_variables_initializer())

26 changes: 26 additions & 0 deletions adanet/core/estimator.py
@@ -411,6 +411,13 @@ class Estimator(tf.estimator.Estimator):
export_subnetwork_logits: Whether to include subnetwork logits in exports.
export_subnetwork_last_layer: Whether to include subnetwork last layer in
exports.
replay_config: Optional :class:`adanet.replay.Config` to specify a previous
AdaNet run to replay. Given the exact same search space but potentially
different training data, the `replay_config` causes the estimator to
reconstruct the previously trained model without performing a search.
NOTE: The previous run must have executed with hyperparameters identical
to those of the new run in order to be replayable. The only supported
difference is that the underlying training data may change.
**kwargs: Extra keyword args passed to the parent.
Returns:
@@ -455,6 +462,7 @@ def __init__(self,
max_iterations=None,
export_subnetwork_logits=False,
export_subnetwork_last_layer=True,
replay_config=None,
**kwargs):
if subnetwork_generator is None:
raise ValueError("subnetwork_generator can't be None.")
@@ -487,6 +495,7 @@ def __init__(self,
self._worker_wait_secs = worker_wait_secs
self._worker_wait_timeout_secs = worker_wait_timeout_secs
self._max_iterations = max_iterations
self._replay_config = replay_config

# Added for backwards compatibility.
default_ensembler_args = [
@@ -912,6 +921,15 @@ def experimental_export_all_saved_models(self,
def _compute_best_ensemble_index(self, checkpoint_path):
"""Runs the Evaluator to obtain the best ensemble index among candidates."""

# AdaNet Replay.
if self._replay_config:
iteration_number = (
self._checkpoint_path_iteration_number(checkpoint_path)
if checkpoint_path else self._latest_checkpoint_iteration_number())
best_index = self._replay_config.get_best_ensemble_index(iteration_number)
if best_index is not None:
return best_index

if self._evaluator:
return self._execute_candidate_evaluation_phase(
self._evaluator.input_fn,
@@ -1143,6 +1161,12 @@ def _get_best_ensemble_index(self, current_iteration, input_hooks):
Returns:
Index of the best ensemble in the iteration's list of `_Candidates`.
"""
# AdaNet Replay.
if self._replay_config:
best_index = self._replay_config.get_best_ensemble_index(
current_iteration.number)
if best_index is not None:
return best_index

# Skip the evaluation phase when there is only one candidate subnetwork.
if len(current_iteration.candidates) == 1:
@@ -1538,6 +1562,8 @@ def _architecture_ensemble_spec(self, architecture, iteration_number,
assert len(current_iteration.candidates) == max_candidates
previous_ensemble_spec = current_iteration.candidates[-1].ensemble_spec
previous_ensemble = previous_ensemble_spec.ensemble
previous_ensemble_spec.architecture.set_replay_indices(
architecture.replay_indices)
return previous_ensemble_spec

def _collate_subnetwork_reports(self, iteration_number):
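The estimator only ever calls `get_best_ensemble_index(iteration_number)` and falls back to the normal evaluation phase when it returns `None`. A compatible implementation could therefore be as small as the following sketch; this is an assumption about `adanet.replay.Config`'s internals, not code from this commit.

class ReplayConfigSketch(object):
  """Illustrative stand-in for adanet.replay.Config."""

  def __init__(self, best_ensemble_indices=None):
    # One best-candidate index per boosting iteration of the original run.
    self._best_ensemble_indices = best_ensemble_indices or []

  def get_best_ensemble_index(self, iteration_number):
    """Returns the recorded winner, or None to trigger normal evaluation."""
    if iteration_number < len(self._best_ensemble_indices):
      return self._best_ensemble_indices[iteration_number]
    return None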
