diff --git a/RELEASE.md b/RELEASE.md
index fa4437f8..3be1549c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -15,6 +15,7 @@ limitations under the License.
 
 # Current version (0.8.0.dev)
 * Under development.
+  * Added AdaNet replay: the ability to rerun training without having to determine the best candidate for each iteration. A list of the best indices from the previous run is provided to and honored by AdaNet.
 * TODO: Add official Keras Model support, including Keras layers, Sequential, and Model subclasses for defining subnetworks.
 * Introduced `adanet.ensemble.MeanEnsembler` with a basic implementation for taking the mean of logits of subnetworks. This also supports including the mean of last_layer (helpful if subnetworks have same configurations) in the `predictions` and `export_outputs` of the EstimatorSpec.
 * **BREAKING CHANGE**: AdaNet now supports arbitrary metrics when choosing the best ensemble. To achieve this, the interface of `adanet.Evaluator` is changing. The `Evaluator.evaluate_adanet_losses(sess, adanet_losses)` function is being replaced with `Evaluator.evaluate(sess, ensemble_metrics)`. The `ensemble_metrics` parameter contains all computed metrics for each candidate ensemble as well as the `adanet_loss`. Code which overrides `evaluate_adanet_losses` must migrate over to use the new `evaluate` method (we suspect that such cases are very rare).
diff --git a/adanet/BUILD b/adanet/BUILD
index cdda5499..76e20943 100644
--- a/adanet/BUILD
+++ b/adanet/BUILD
@@ -17,6 +17,10 @@ py_library(
     deps = [
         "//adanet/autoensemble",
         "//adanet/core",
+        "//adanet/distributed",
+        "//adanet/ensemble",
+        "//adanet/replay",
+        "//adanet/subnetwork",
     ],
 )
diff --git a/adanet/__init__.py b/adanet/__init__.py
index 0a89346a..1d4fb189 100644
--- a/adanet/__init__.py
+++ b/adanet/__init__.py
@@ -20,6 +20,7 @@
 from adanet import distributed
 from adanet import ensemble
+from adanet import replay
 from adanet import subnetwork
 from adanet.autoensemble import AutoEnsembleEstimator
 from adanet.autoensemble import AutoEnsembleSubestimator
@@ -43,6 +44,7 @@
     "Ensemble",
     "Estimator",
     "Evaluator",
+    "replay",
    "ReportMaterializer",
    "subnetwork",
    "Summary",
diff --git a/adanet/adanet_test.py b/adanet/adanet_test.py
index 4f83c056..fe0c3586 100644
--- a/adanet/adanet_test.py
+++ b/adanet/adanet_test.py
@@ -49,6 +49,7 @@ def test_public(self):
     self.assertIsNotNone(adanet.Estimator)
     self.assertIsNotNone(adanet.Evaluator)
     self.assertIsNotNone(adanet.MixtureWeightType)
+    self.assertIsNotNone(adanet.replay.Config)
     self.assertIsNotNone(adanet.ReportMaterializer)
     self.assertIsNotNone(adanet.Subnetwork)
     self.assertIsNotNone(adanet.subnetwork.Builder)
diff --git a/adanet/autoensemble/estimator.py b/adanet/autoensemble/estimator.py
index a00a656e..655c6773 100644
--- a/adanet/autoensemble/estimator.py
+++ b/adanet/autoensemble/estimator.py
@@ -358,6 +358,9 @@ def input_fn_predict:
       debug: See :class:`adanet.Estimator`.
       enable_ensemble_summaries: See :class:`adanet.Estimator`.
       enable_subnetwork_summaries: See :class:`adanet.Estimator`.
+      global_step_combiner_fn: See :class:`adanet.Estimator`.
+      max_iterations: See :class:`adanet.Estimator`.
+      replay_config: See :class:`adanet.Estimator`.
      **kwargs: Extra keyword args passed to the parent.
 
    Returns:
@@ -387,6 +390,9 @@ def __init__(self,
                debug=False,
                enable_ensemble_summaries=True,
                enable_subnetwork_summaries=True,
+               global_step_combiner_fn=tf.math.reduce_mean,
+               max_iterations=None,
+               replay_config=None,
                **kwargs):
     subnetwork_generator = _GeneratorFromCandidatePool(candidate_pool,
                                                        logits_fn, last_layer_fn)
@@ -406,4 +412,7 @@ def __init__(self,
         debug=debug,
         enable_ensemble_summaries=enable_ensemble_summaries,
         enable_subnetwork_summaries=enable_subnetwork_summaries,
+        global_step_combiner_fn=global_step_combiner_fn,
+        max_iterations=max_iterations,
+        replay_config=replay_config,
         **kwargs)
diff --git a/adanet/core/BUILD b/adanet/core/BUILD
index 2dcb1126..fab8f8ed 100644
--- a/adanet/core/BUILD
+++ b/adanet/core/BUILD
@@ -52,6 +52,7 @@ py_test(
     deps = [
         ":evaluator",
         ":report_materializer",
         ":testing_utils",
+        "//adanet/replay",
         "//adanet/subnetwork",
         "@absl_py//absl/testing:parameterized",
     ],
@@ -148,6 +149,7 @@ py_test(
     name = "iteration_test",
     srcs = ["iteration_test.py"],
     deps = [
+        ":architecture",
         ":candidate",
         ":ensemble_builder",
         ":iteration",
@@ -194,6 +196,7 @@ py_test(
     srcs = ["ensemble_builder_test.py"],
     shard_count = 10,
     deps = [
+        ":architecture",
        ":ensemble_builder",
        ":summary",
        ":testing_utils",
diff --git a/adanet/core/architecture.py b/adanet/core/architecture.py
index 906af217..1b18aafa 100644
--- a/adanet/core/architecture.py
+++ b/adanet/core/architecture.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import json
 
@@ -28,15 +29,20 @@ class _Architecture(object):
   `adanet.subnetwork.Builder` instances that compose the ensemble, the
   `adanet.ensemble.Ensembler` that constructed it, as well as the sequence
   of states in the search space that led to the construction of this model.
+  In addition, it stores `replay_indices`: a list of indices (one per
+  boosting iteration) holding the index of the ensemble in the candidate
+  list throughout the run.
 
   It is serializable and deserializable for persistent storage.
   """
 
-  def __init__(self, ensemble_candidate_name, ensembler_name, global_step=None):
+  def __init__(self, ensemble_candidate_name, ensembler_name, global_step=None,
+               replay_indices=None):
     self._ensemble_candidate_name = ensemble_candidate_name
     self._ensembler_name = ensembler_name
     self._global_step = global_step
     self._subnets = []
+    self._replay_indices = replay_indices or []
 
   @property
   def ensemble_candidate_name(self):
@@ -76,6 +82,17 @@ def subnetworks(self):
 
     return tuple(self._subnets)
 
+  @property
+  def replay_indices(self):
+    """The list of replay indices.
+
+    Returns:
+      A list of integers (one per boosting iteration) holding the index of
+      the ensemble in the candidate list throughout the run.
+    """
+
+    return self._replay_indices
+
   @property
   def subnetworks_grouped_by_iteration(self):
     """The component subnetworks grouped by iteration number.
@@ -105,6 +122,13 @@ def add_subnetwork(self, iteration_number, builder_name):
     """
     self._subnets.append((iteration_number, builder_name))
 
+  # TODO: Remove setters and getters.
+  def add_replay_index(self, index):
+    self._replay_indices.append(index)
+
+  def set_replay_indices(self, indices):
+    self._replay_indices = copy.copy(indices)
+
   def serialize(self, global_step):
     """Returns a string serialization of this object."""
 
@@ -115,6 +139,7 @@ def serialize(self, global_step):
         "global_step": global_step,
         "ensembler_name": self.ensembler_name,
         "subnetworks": [],
+        "replay_indices": self._replay_indices
     }
     for iteration_number, builder_name in self._subnets:
       subnetwork_arch = {
@@ -139,7 +164,8 @@ def deserialize(serialized_architecture):
     ensemble_arch = json.loads(serialized_architecture)
     architecture = _Architecture(ensemble_arch["ensemble_candidate_name"],
                                  ensemble_arch["ensembler_name"],
-                                 ensemble_arch["global_step"])
+                                 ensemble_arch["global_step"],
+                                 ensemble_arch["replay_indices"])
     for subnet in ensemble_arch["subnetworks"]:
       architecture.add_subnetwork(subnet["iteration_number"],
                                   subnet["builder_name"])
diff --git a/adanet/core/architecture_test.py b/adanet/core/architecture_test.py
index 6b357352..7ccf77b2 100644
--- a/adanet/core/architecture_test.py
+++ b/adanet/core/architecture_test.py
@@ -72,8 +72,15 @@ def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
       arch.add_subnetwork(*subnetwork)
     self.assertEqual(want, arch.subnetworks_grouped_by_iteration)
 
-  def test_serialization_lifecycle(self):
+  def test_set_and_add_replay_index(self):
     arch = _Architecture("foo", "dummy_ensembler_name")
+    arch.set_replay_indices([1, 2, 3])
+    self.assertAllEqual([1, 2, 3], arch.replay_indices)
+    arch.add_replay_index(4)
+    self.assertAllEqual([1, 2, 3, 4], arch.replay_indices)
+
+  def test_serialization_lifecycle(self):
+    arch = _Architecture("foo", "dummy_ensembler_name", replay_indices=[1, 2])
     arch.add_subnetwork(0, "linear")
     arch.add_subnetwork(0, "dnn")
     arch.add_subnetwork(1, "dnn")
@@ -85,10 +92,10 @@ def test_serialization_lifecycle(self):
     serialized = arch.serialize(global_step)
     self.assertEqual(
         '{"ensemble_candidate_name": "foo", "ensembler_name": '
-        '"dummy_ensembler_name", "global_step": 100, "subnetworks": '
-        '[{"builder_name": "linear", "iteration_number": 0}, {"builder_name": '
-        '"dnn", "iteration_number": 0}, {"builder_name": "dnn", '
-        '"iteration_number": 1}]}', serialized)
+        '"dummy_ensembler_name", "global_step": 100, "replay_indices": [1, 2], '
+        '"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
+        '{"builder_name": "dnn", "iteration_number": 0},'
+        ' {"builder_name": "dnn", "iteration_number": 1}]}', serialized)
     deserialized_arch = _Architecture.deserialize(serialized)
     self.assertEqual(arch.ensemble_candidate_name,
                      deserialized_arch.ensemble_candidate_name)
diff --git a/adanet/core/ensemble_builder.py b/adanet/core/ensemble_builder.py
index d2dc843c..4c4a90f4 100644
--- a/adanet/core/ensemble_builder.py
+++ b/adanet/core/ensemble_builder.py
@@ -21,6 +21,7 @@
 import collections
 import contextlib
+import copy
 import functools
 import inspect
 
@@ -270,6 +271,7 @@ def build_ensemble_spec(self,
                           mode,
                           iteration_number,
                           labels=None,
+                          my_ensemble_index=None,
                           previous_ensemble_spec=None):
     """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.
 
@@ -286,6 +288,8 @@ def build_ensemble_spec(self,
       iteration_number: Integer current iteration number.
       labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
         (for multi-head).
+      my_ensemble_index: An integer holding the index of this ensemble in
+        AdaNet's candidate list.
       previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
         iteration t-1.
        Used for creating the subnetwork train_op.
 
@@ -304,7 +308,15 @@ def build_ensemble_spec(self,
       step_tensor = tf.convert_to_tensor(value=step)
       with summary.current_scope():
         summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
-      architecture = _Architecture(candidate.name, ensembler.name)
+      replay_indices = []
+      if previous_ensemble_spec:
+        replay_indices = copy.copy(
+            previous_ensemble_spec.architecture.replay_indices)
+      if my_ensemble_index is not None:
+        replay_indices.append(my_ensemble_index)
+
+      architecture = _Architecture(candidate.name, ensembler.name,
+                                   replay_indices=replay_indices)
       previous_subnetworks = []
       subnetwork_builders = []
       previous_ensemble = None
diff --git a/adanet/core/ensemble_builder_test.py b/adanet/core/ensemble_builder_test.py
index fd5f1914..410108d0 100644
--- a/adanet/core/ensemble_builder_test.py
+++ b/adanet/core/ensemble_builder_test.py
@@ -313,6 +313,15 @@ class EnsembleBuilderTest(tu.AdanetTestCase):
           "want_ensemble_trainable_vars": 2,
           "want_subnetwork_trainable_vars": 4,
           "export_subnetworks": True,
+      }, {
+          "testcase_name": "replay_no_prev",
+          "adanet_beta": .1,
+          "want_logits": [[.006], [.082]],
+          "want_loss": 1.349,
+          "want_adanet_loss": 1.360,
+          "want_ensemble_trainable_vars": 1,
+          "my_ensemble_index": 2,
+          "want_replay_indices": [2],
       })
  @test_util.run_in_graph_and_eager_modes
  def test_build_ensemble_spec(
@@ -334,6 +343,8 @@ def test_build_ensemble_spec(
       multi_head=False,
       want_subnetwork_trainable_vars=2,
       ensembler_class=ComplexityRegularizedEnsembler,
+      my_ensemble_index=None,
+      want_replay_indices=None,
       want_predictions=None,
       export_subnetworks=False):
     seed = 64
@@ -453,8 +464,13 @@ def _mixture_weights_train_op_fn(loss, var_list):
         features=features,
         iteration_number=1,
         labels=labels,
+        my_ensemble_index=my_ensemble_index,
         mode=mode)
 
+    if want_replay_indices:
+      self.assertAllEqual(want_replay_indices,
+                          ensemble_spec.architecture.replay_indices)
+
     with tf_compat.v1.Session(graph=g).as_default() as sess:
       sess.run(tf_compat.v1.global_variables_initializer())
diff --git a/adanet/core/estimator.py b/adanet/core/estimator.py
index efee8fbd..03b4e685 100644
--- a/adanet/core/estimator.py
+++ b/adanet/core/estimator.py
@@ -411,6 +411,13 @@ class Estimator(tf.estimator.Estimator):
     export_subnetwork_logits: Whether to include subnetwork logits in exports.
     export_subnetwork_last_layer: Whether to include subnetwork last layer in
       exports.
+    replay_config: Optional :class:`adanet.replay.Config` to specify a previous
+      AdaNet run to replay. Given the exact same search space but potentially
+      different training data, the `replay_config` causes the estimator to
+      reconstruct the previously trained model without performing a search.
+      NOTE: The previous run must have been executed with the same
+      hyperparameters as the new run in order to be replayable. The only
+      supported difference is that the underlying data can change.
     **kwargs: Extra keyword args passed to the parent.
 
   Returns:
@@ -455,6 +462,7 @@ def __init__(self,
                max_iterations=None,
                export_subnetwork_logits=False,
                export_subnetwork_last_layer=True,
+               replay_config=None,
                **kwargs):
     if subnetwork_generator is None:
       raise ValueError("subnetwork_generator can't be None.")
@@ -487,6 +495,7 @@ def __init__(self,
     self._worker_wait_secs = worker_wait_secs
     self._worker_wait_timeout_secs = worker_wait_timeout_secs
     self._max_iterations = max_iterations
+    self._replay_config = replay_config
 
     # Added for backwards compatibility.
     default_ensembler_args = [
@@ -912,6 +921,15 @@ def experimental_export_all_saved_models(self,
 
   def _compute_best_ensemble_index(self, checkpoint_path):
     """Runs the Evaluator to obtain the best ensemble index among candidates."""
+    # AdaNet Replay.
+    if self._replay_config:
+      iteration_number = (
+          self._checkpoint_path_iteration_number(checkpoint_path)
+          if checkpoint_path else self._latest_checkpoint_iteration_number())
+      best_index = self._replay_config.get_best_ensemble_index(iteration_number)
+      if best_index is not None:
+        return best_index
+
     if self._evaluator:
       return self._execute_candidate_evaluation_phase(
           self._evaluator.input_fn,
@@ -1143,6 +1161,12 @@ def _get_best_ensemble_index(self, current_iteration, input_hooks):
     Returns:
       Index of the best ensemble in the iteration's list of `_Candidates`.
     """
+    # AdaNet Replay.
+    if self._replay_config:
+      best_index = self._replay_config.get_best_ensemble_index(
+          current_iteration.number)
+      if best_index is not None:
+        return best_index
 
     # Skip the evaluation phase when there is only one candidate subnetwork.
     if len(current_iteration.candidates) == 1:
@@ -1538,6 +1562,8 @@ def _architecture_ensemble_spec(self, architecture, iteration_number,
       assert len(current_iteration.candidates) == max_candidates
       previous_ensemble_spec = current_iteration.candidates[-1].ensemble_spec
       previous_ensemble = previous_ensemble_spec.ensemble
+      previous_ensemble_spec.architecture.set_replay_indices(
+          architecture.replay_indices)
     return previous_ensemble_spec
 
   def _collate_subnetwork_reports(self, iteration_number):
diff --git a/adanet/core/estimator_test.py b/adanet/core/estimator_test.py
index 9f7c69ea..b6b023de 100644
--- a/adanet/core/estimator_test.py
+++ b/adanet/core/estimator_test.py
@@ -25,6 +25,7 @@
 from absl import logging
 from absl.testing import parameterized
+from adanet import replay
 from adanet import tf_compat
 from adanet.core import testing_utils as tu
 from adanet.core.estimator import Estimator
@@ -3141,5 +3142,85 @@ def test_train(self):
+
+
+class EstimatorReplayTest(tu.AdanetTestCase):
+
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "no_evaluator",
+          "evaluator": None,
+          "replay_evaluator": None,
+          "want_architecture": " dnn3 | dnn3 | dnn ",
+      }, {
+          "testcase_name":
+              "evaluator",
+          "evaluator":
+              Evaluator(
+                  input_fn=tu.dummy_input_fn(XOR_FEATURES, XOR_LABELS),
+                  steps=1),
+          "replay_evaluator":
+              Evaluator(
+                  input_fn=tu.dummy_input_fn([[0., 0.], [0., 0], [0., 0.],
+                                              [0., 0.]], [[0], [0], [0], [0]]),
+                  steps=1),
+          "want_architecture":
+              " dnn3 | dnn3 | dnn ",
+      })
+  def test_replay(self, evaluator, replay_evaluator, want_architecture):
+    """Trains the entire estimator lifecycle using replay."""
+
+    original_model_dir = os.path.join(self.test_subdirectory, "original")
+    run_config = tf.estimator.RunConfig(
+        tf_random_seed=42, model_dir=original_model_dir)
+    subnetwork_generator = SimpleGenerator([
+        _DNNBuilder("dnn"),
+        _DNNBuilder("dnn2", layer_size=3),
+        _DNNBuilder("dnn3", layer_size=5),
+    ])
+    estimator = Estimator(
+        head=tu.head(),
+        subnetwork_generator=subnetwork_generator,
+        max_iteration_steps=10,
+        evaluator=evaluator,
+        config=run_config)
+
+    train_input_fn = tu.dummy_input_fn(XOR_FEATURES, XOR_LABELS)
+
+    # Train for three iterations.
+    estimator.train(input_fn=train_input_fn, max_steps=30)
+
+    # Evaluate.
+    eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertIn(want_architecture,
+                  str(eval_results["architecture/adanet/ensembles"]))
+
+    replay_run_config = tf.estimator.RunConfig(
+        tf_random_seed=42,
+        model_dir=os.path.join(self.test_subdirectory, "replayed"))
+
+    # Use different features and labels to represent a shift in the data
+    # distribution.
+    different_features = [[0., 0.], [0., 0], [0., 0.], [0., 0.]]
+    different_labels = [[0], [0], [0], [0]]
+
+    replay_estimator = Estimator(
+        head=tu.head(),
+        subnetwork_generator=subnetwork_generator,
+        max_iteration_steps=10,
+        evaluator=replay_evaluator,
+        config=replay_run_config,
+        replay_config=replay.Config(best_ensemble_indices=[2, 3, 1]))
+
+    train_input_fn = tu.dummy_input_fn(different_features, different_labels)
+
+    # Train for three iterations.
+    replay_estimator.train(input_fn=train_input_fn, max_steps=30)
+
+    # Evaluate.
+    eval_results = replay_estimator.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertIn(want_architecture,
+                  str(eval_results["architecture/adanet/ensembles"]))
+
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/adanet/core/eval_metrics.py b/adanet/core/eval_metrics.py
index 2d32eabf..ed0f373c 100644
--- a/adanet/core/eval_metrics.py
+++ b/adanet/core/eval_metrics.py
@@ -291,10 +291,15 @@ def _architecture_metric_fn():
 class _IterationMetrics(object):
   """A object which creates evaluation metrics for an Iteration."""
 
-  def __init__(self, iteration_number, candidates, subnetwork_specs):
+  def __init__(self,
+               iteration_number,
+               candidates,
+               subnetwork_specs,
+               replay_indices_for_all=None):
     self._iteration_number = iteration_number
     self._candidates = candidates
     self._subnetwork_specs = subnetwork_specs
+    self._replay_indices_for_all = replay_indices_for_all
 
     self._candidates_eval_metrics_store = self._build_eval_metrics_store(
         [candidate.ensemble_spec for candidate in self._candidates])
@@ -343,6 +348,26 @@ def best_eval_metrics_tuple(self, best_candidate_index, mode):
     args = candidate_args + subnetwork_args
     args.append(tf.reshape(best_candidate_index, [1]))
 
+    def _replay_eval_metrics(best_candidate_idx, eval_metric_ops):
+      """Saves replay indices as eval metrics."""
+      # _replay_indices_for_all is a dict: {candidate index: [replay indices]}.
+      # Find the length of the longest replay-index list.
+      pad_value = max(
+          [len(v) for _, v in self._replay_indices_for_all.items()])
+
+      # Create a matrix of shape (#candidates) x (max replay-index length).
+      # Entry i,j is the jth replay index of the ith candidate (ensemble).
+      replay_indices_as_tensor = tf.constant([
+          value + [-1] * (pad_value - len(value))
+          for _, value in self._replay_indices_for_all.items()
+      ])
+
+      # Pass along only the entries of the best candidate.
+      for iteration in range(replay_indices_as_tensor.get_shape()[1].value):
+        index_t = replay_indices_as_tensor[best_candidate_idx, iteration]
+        eval_metric_ops["best_ensemble_index_{}".format(iteration)] = (index_t,
+                                                                       index_t)
+
     def _best_eval_metrics_fn(*args):
       """Returns the best eval metrics."""
 
@@ -355,6 +380,7 @@ def _best_eval_metrics_fn(*args):
       else:
         idx, idx_update_op = tf_compat.v1.metrics.mean(args.pop())
 
+      idx = tf.cast(idx, tf.int32)
       metric_fns = self._candidates_eval_metrics_store.metric_fns
       metric_fn_args = self._candidates_eval_metrics_store.pack_args(
           args[:len(candidate_args)])
@@ -376,11 +402,11 @@ def _best_eval_metrics_fn(*args):
           continue
         if tf.executing_eagerly():
           values = [m.result() for m in metric_ops]
-          best_value = tf.stack(values)[tf.cast(idx, tf.int32)]
+          best_value = tf.stack(values)[idx]
           eval_metric_ops[metric_name] = (best_value, None)
           continue
         values, ops = list(six.moves.zip(*metric_ops))
-        best_value = tf.stack(values)[tf.cast(idx, tf.int32)]
+        best_value = tf.stack(values)[idx]
         # All tensors in this function have been outfed from the TPU, so we
         # must update them manually, otherwise the TPU will hang indefinitely
         # for the value of idx to update.
@@ -395,6 +421,9 @@ def _best_eval_metrics_fn(*args):
       iteration_number = tf.constant(self._iteration_number)
       eval_metric_ops["iteration"] = (iteration_number, iteration_number)
 
+      if self._replay_indices_for_all:
+        _replay_eval_metrics(idx, eval_metric_ops)
+
       # tf.estimator.Estimator does not allow a "loss" key to be present in
       # its eval_metrics.
       assert "loss" not in eval_metric_ops
diff --git a/adanet/core/iteration.py b/adanet/core/iteration.py
index 05f8f508..4e27b5f8 100644
--- a/adanet/core/iteration.py
+++ b/adanet/core/iteration.py
@@ -21,6 +21,7 @@
 import collections
 import contextlib
+import copy
 import json
 import os
 
@@ -553,7 +554,7 @@ def build_iteration(self,
       builder_mode = mode
 
     features, labels = self._check_numerics(features, labels)
-
+    replay_indices_for_all = {}
     training = mode == tf.estimator.ModeKeys.TRAIN
     skip_summaries = mode == tf.estimator.ModeKeys.PREDICT or rebuilding
     with tf_compat.v1.variable_scope("iteration_{}".format(iteration_number)):
@@ -565,6 +566,8 @@ def build_iteration(self,
 
       if previous_ensemble_spec:
         previous_ensemble = previous_ensemble_spec.ensemble
+        replay_indices_for_all[len(candidates)] = copy.copy(
+            previous_ensemble_spec.architecture.replay_indices)
         # Include previous best subnetwork as a candidate so that its
         # predictions are returned until a new candidate outperforms.
         seen_builder_names = {previous_ensemble_spec.name: True}
@@ -680,12 +683,15 @@ def build_iteration(self,
               mode=builder_mode,
               iteration_number=iteration_number,
               labels=labels,
+              my_ensemble_index=len(candidates),
               previous_ensemble_spec=previous_ensemble_spec)
           # TODO: Eliminate need for candidates.
           # TODO: Don't track moving average of loss when rebuilding
          # previous ensemble.
          candidate = self._candidate_builder.build_candidate(
               ensemble_spec=ensemble_spec, training=training, summary=summary)
+          replay_indices_for_all[len(candidates)] = copy.copy(
+              ensemble_spec.architecture.replay_indices)
           candidates.append(candidate)
           # TODO: Move adanet_loss from subnetwork report to a new
           # ensemble report, since the adanet_loss is associated with an
@@ -725,7 +731,8 @@ def build_iteration(self,
         if best_loss is not None:
           summary.scalar("loss", best_loss)
       iteration_metrics = _IterationMetrics(iteration_number, candidates,
-                                            subnetwork_specs)
+                                            subnetwork_specs,
+                                            replay_indices_for_all)
       if self._use_tpu:
         estimator_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
             mode=mode,
diff --git a/adanet/core/iteration_test.py b/adanet/core/iteration_test.py
index 86f0d447..360cca2b 100644
--- a/adanet/core/iteration_test.py
+++ b/adanet/core/iteration_test.py
@@ -23,6 +23,7 @@
 from absl.testing import parameterized
 from adanet import tf_compat
+from adanet.core.architecture import _Architecture
 from adanet.core.candidate import _Candidate
 from adanet.core.ensemble_builder import _EnsembleSpec
 from adanet.core.ensemble_builder import _SubnetworkSpec
@@ -252,6 +253,7 @@ def build_ensemble_spec(self,
                           iteration_number,
                           labels=None,
                           previous_ensemble_spec=None,
+                          my_ensemble_index=None,
                           params=None):
     del ensembler
     del subnetwork_specs
@@ -261,6 +263,7 @@ def build_ensemble_spec(self,
     del labels
     del iteration_number
     del params
+    del my_ensemble_index
 
     num_subnetworks = 0
     if previous_ensemble_spec:
@@ -861,12 +864,14 @@ def build_ensemble_spec(self,
                           iteration_number,
                           labels=None,
                           previous_ensemble_spec=None,
+                          my_ensemble_index=None,
                           params=None):
     del ensembler
     del subnetwork_specs
     del summary
     del iteration_number
     del previous_ensemble_spec
+    del my_ensemble_index
     del params
 
     logits = [[.5]]
@@ -876,7 +881,7 @@ def build_ensemble_spec(self,
     return _EnsembleSpec(
         name=name,
         ensemble=None,
-        architecture=None,
+        architecture=_Architecture("foo", "bar"),
         subnetwork_builders=candidate.subnetwork_builders,
         predictions=estimator_spec.predictions,
         step=tf.Variable(0),
diff --git a/adanet/core/tpu_estimator.py b/adanet/core/tpu_estimator.py
index 08d7f134..14cfc1e5 100644
--- a/adanet/core/tpu_estimator.py
+++ b/adanet/core/tpu_estimator.py
@@ -71,6 +71,9 @@ class TPUEstimator(Estimator, tf_compat.v1.estimator.tpu.TPUEstimator):
     export_subnetwork_logits: Whether to include subnetwork logits in exports.
     export_subnetwork_last_layer: Whether to include subnetwork last layer in
       exports.
+    global_step_combiner_fn: See :class:`adanet.Estimator`.
+    max_iterations: See :class:`adanet.Estimator`.
+    replay_config: See :class:`adanet.Estimator`.
     **kwargs: Extra keyword args passed to the parent.
   """
 
@@ -98,6 +101,9 @@ def __init__(self,
                enable_subnetwork_summaries=True,
                export_subnetwork_logits=False,
                export_subnetwork_last_layer=True,
+               global_step_combiner_fn=tf.math.reduce_mean,
+               max_iterations=None,
+               replay_config=None,
                **kwargs):
 
     if tf_compat.version_greater_or_equal("2.0.0"):
@@ -142,6 +148,9 @@ def __init__(self,
         enable_subnetwork_summaries=enable_subnetwork_summaries,
         export_subnetwork_logits=export_subnetwork_logits,
         export_subnetwork_last_layer=export_subnetwork_last_layer,
+        global_step_combiner_fn=global_step_combiner_fn,
+        max_iterations=max_iterations,
+        replay_config=replay_config,
         **kwargs)
 
     # Yields predictions on CPU even when use_tpu=True.
diff --git a/adanet/replay/BUILD b/adanet/replay/BUILD
new file mode 100644
index 00000000..e92f4b5c
--- /dev/null
+++ b/adanet/replay/BUILD
@@ -0,0 +1,14 @@
+# Description:
+#   AdaNet replay.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "replay",
+    srcs = ["__init__.py"],
+    visibility = ["//adanet:__subpackages__"],
+    deps = [
+    ],
+)
diff --git a/adanet/replay/__init__.py b/adanet/replay/__init__.py
new file mode 100644
index 00000000..77226937
--- /dev/null
+++ b/adanet/replay/__init__.py
@@ -0,0 +1,62 @@
+# Copyright 2019 The AdaNet Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Defines mechanisms for deterministically replaying an AdaNet model search."""
+
+# TODO: Add more detailed documentation.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+import tensorflow as tf
+
+
+class Config(object):  # pylint: disable=g-classes-have-attributes
+  # pyformat: disable
+  """Defines how to deterministically replay an AdaNet model search.
+
+  Specifically, it reconstructs the previous model and trains its components
+  in the correct order without performing any search.
+
+  Args:
+    best_ensemble_indices: A list of the best ensemble indices (one per
+      iteration).
+
+  Returns:
+    An :class:`adanet.replay.Config` instance.
+  """
+  # pyformat: enable
+
+  def __init__(self, best_ensemble_indices=None):
+    self._best_ensemble_indices = best_ensemble_indices
+
+  @property
+  def best_ensemble_indices(self):
+    """The best ensemble indices per iteration."""
+    return self._best_ensemble_indices
+
+  def get_best_ensemble_index(self, iteration_number):
+    """Returns the best ensemble index given an iteration number."""
+    # Return the stored index if the list covers this iteration.
+    if (self._best_ensemble_indices
+        and iteration_number < len(self._best_ensemble_indices)):
+      return self._best_ensemble_indices[iteration_number]
+
+    return None
+
+
+__all__ = ["Config"]
diff --git a/docs/source/adanet.replay.rst b/docs/source/adanet.replay.rst
new file mode 100644
index 00000000..03f68bc2
--- /dev/null
+++ b/docs/source/adanet.replay.rst
@@ -0,0 +1,15 @@
+.. role:: hidden
+    :class: hidden-section
+
+adanet.replay
+==============================
+
+
+.. automodule:: adanet.replay
+.. currentmodule:: adanet.replay
+
+:hidden:`Config`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Config
+    :members:
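For reviewers, a minimal end-to-end sketch of the new replay flow, mirroring `EstimatorReplayTest` above. The `head`, `subnetwork_generator`, and input functions are placeholders standing in for your own objects; only `replay_config` and `adanet.replay.Config` come from this change:

```python
import adanet
import tensorflow as tf

# Original run: a normal AdaNet search, e.g. three iterations of 10 steps.
estimator = adanet.Estimator(
    head=head,                                  # placeholder, e.g. a regression head
    subnetwork_generator=subnetwork_generator,  # placeholder generator
    max_iteration_steps=10,
    config=tf.estimator.RunConfig(
        tf_random_seed=42, model_dir="/tmp/original"))
estimator.train(input_fn=train_input_fn, max_steps=30)

# Replay run: rebuild the same model without searching. Each entry in
# best_ensemble_indices is the index of the winning ensemble in that
# iteration's candidate list from the original run.
replay_estimator = adanet.Estimator(
    head=head,
    subnetwork_generator=subnetwork_generator,  # identical search space
    max_iteration_steps=10,
    config=tf.estimator.RunConfig(
        tf_random_seed=42, model_dir="/tmp/replayed"),
    replay_config=adanet.replay.Config(best_ensemble_indices=[2, 3, 1]))

# The training data may differ; the architecture decisions are replayed as-is.
replay_estimator.train(input_fn=new_train_input_fn, max_steps=30)
```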
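Relatedly, `_IterationMetrics` now exports each winning index as a `best_ensemble_index_{i}` eval metric (see `_replay_eval_metrics` in `eval_metrics.py`), so the indices needed for a later replay can be collected from the original run's evaluation results. A sketch with a hypothetical helper; the function name and loop are assumptions, and only the metric-name pattern comes from the diff:

```python
def best_ensemble_indices_from_eval(eval_results):
  """Collects recorded replay indices from `Estimator.evaluate` results (hypothetical helper)."""
  indices = []
  i = 0
  # Metric names follow _replay_eval_metrics: best_ensemble_index_0, _1, ...
  while "best_ensemble_index_{}".format(i) in eval_results:
    indices.append(int(eval_results["best_ensemble_index_{}".format(i)]))
    i += 1
  return indices

eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)
replay_config = adanet.replay.Config(
    best_ensemble_indices=best_ensemble_indices_from_eval(eval_results))
```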