
Commit

Merge pull request #44 from mr-ubik/42-model-selection-json-overwritten-on-restart

Fix Trainer restart issues
galeone committed Feb 7, 2020
2 parents 7dc0b40 + 7efd7dd commit 8320af1
Showing 19 changed files with 555 additions and 299 deletions.
15 changes: 9 additions & 6 deletions conftest.py
@@ -30,8 +30,8 @@
     SSIM_Multiscale,
 )
 from tests.utils.fake_training_loop import (
-    fake_adversarial_training_loop,
-    fake_classifier_training_loop,
+    FakeAdversarialTraining,
+    FakeClassifierTraining,
 )


@@ -70,7 +70,7 @@ def save_dir():
 TEST_MATRIX = {
     # NOTE: Always pass metrics as Tuple, Trainers produce side effects!
     "adversarial_trainer": [
-        fake_adversarial_training_loop,
+        FakeAdversarialTraining,
         {
             "image_resolution": [256, 256],
             "layer_spec_input_res": (8, 8),
@@ -101,7 +101,7 @@ def save_dir():
         ),
     ],
     "classifier_trainer": [
-        fake_classifier_training_loop,
+        FakeClassifierTraining,
         {"measure_performance_freq": 1},
         (ClassifierLoss(model_selection_operator=operator.lt),),
     ],
@@ -112,8 +112,11 @@


 @pytest.fixture(scope="function", params=LOOPS, ids=TRAINING_IDS)
-def fake_training(request):
+def fake_training_fn(request):
     """Fixture used to generate fake training for the tests."""
     training_loop, loop_args, metrics = request.param
     assert len(metrics) in [1, 3]
-    return (training_loop, loop_args, list(metrics))
+
+    return lambda logdir, **kwargs: training_loop(
+        logdir=logdir, metrics=metrics, **loop_args, **kwargs
+    )
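The fixture now returns a factory bound to the parametrized training loop instead of a raw `(loop, args, metrics)` tuple, so every test drives a fresh run against its own `logdir`. A minimal sketch of how a test might consume it (the test body and `tmpdir` usage are illustrative, not part of this diff):

```python
# Illustrative test: exercise restart behavior through the refactored fixture.
def test_restart_in_same_logdir(fake_training_fn, tmpdir):
    logdir = str(tmpdir)
    fake_training_fn(logdir)  # first run: trains, writes checkpoints and JSON files
    fake_training_fn(logdir)  # second run in the same logdir: must restore, not reset
```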
4 changes: 2 additions & 2 deletions src/ashpy/contexts/classifier.py
@@ -16,7 +16,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional, Tuple

 import tensorflow as tf  # pylint: disable=import-error
 from ashpy.contexts.context import Context
@@ -36,7 +36,7 @@ def __init__(
         classifier_model: tf.keras.Model = None,
         loss: ClassifierLoss = None,  # ?: Do we really need to default these values to None?
         dataset: tf.data.Dataset = None,
-        metrics: List[Metric] = None,
+        metrics: Tuple[Metric] = None,
         log_eval_mode: LogEvalMode = LogEvalMode.TEST,
         global_step: tf.Variable = tf.Variable(
             0, name="global_step", trainable=False, dtype=tf.int64
13 changes: 6 additions & 7 deletions src/ashpy/contexts/context.py
@@ -20,7 +20,7 @@
 handle information transfer.
 """

-from typing import List, Optional
+from typing import Optional, Tuple

 import tensorflow as tf
 from ashpy.metrics import Metric
@@ -32,7 +32,7 @@ class Context:

     def __init__(
         self,
-        metrics: List[Metric] = None,
+        metrics: Tuple[Metric] = None,
         dataset: tf.data.Dataset = None,
         log_eval_mode: LogEvalMode = LogEvalMode.TEST,
         global_step=tf.Variable(0, name="global_step", trainable=False, dtype=tf.int64),
@@ -42,7 +42,7 @@ def __init__(
         Initialize the Context.

         Args:
-            metrics (:obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`]): List of
+            metrics (:obj:`tuple` of (:py:class:`ashpy.metrics.metric.Metric`)): Tuple of
                 :py:class:`ashpy.metrics.metric.Metric` objects.
             dataset (:py:class:`tf.data.Dataset`): The dataset to use, that
                 contains everything needed to use the model in this context.
@@ -55,8 +55,7 @@
         """
         self._distribute_strategy = tf.distribute.get_strategy()

-        # TODO: are metrics really needed right now?
-        self._metrics = metrics if metrics else []
+        self._metrics = metrics if metrics else ()
         self._dataset = dataset
         self._log_eval_mode = log_eval_mode
         self._global_step = global_step
@@ -98,12 +97,12 @@ def dataset(self, _dataset: tf.data.Dataset):
         self._dataset = _dataset

     @property
-    def metrics(self) -> List[Metric]:
+    def metrics(self) -> Tuple[Metric]:
         """
         Retrieve the metrics.

         Returns:
-            :obj:`list` of [:py:class:`ashpy.metrics.metric.Metric`].
+            :obj:`tuple` of (:py:class:`ashpy.metrics.metric.Metric`).
         """
         return self._metrics
14 changes: 7 additions & 7 deletions src/ashpy/metrics/metric.py
@@ -113,13 +113,13 @@ def _update_logdir(self):
         # write the initial value of the best metric
         if not self.best_model_sel_file.exists():
             self.best_model_sel_file.parent.mkdir(parents=True)
-        initial_value = (
-            np.inf if self._model_selection_operator is operator.lt else -np.inf
-        )
-        self.json_write(
-            self.best_model_sel_file,
-            {self.sanitized_name: str(initial_value), "step": 0},
-        )
+            initial_value = (
+                np.inf if self._model_selection_operator is operator.lt else -np.inf
+            )
+            self.json_write(
+                self.best_model_sel_file,
+                {self.sanitized_name: str(initial_value), "step": 0},
+            )

     @property
     def name(self) -> str:
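This re-indentation is the heart of issue #42: the initial best-metric value used to be rewritten on every call, clobbering the stored best value whenever training restarted; it is now written only when `best_model_sel_file` does not exist yet. The guarded-initialization pattern as a standalone sketch (the helper name and the use of `write_text` are assumptions, not AshPy's actual API):

```python
import json
import operator
from pathlib import Path

import numpy as np


def init_best_metric_file(best_file: Path, metric_name: str, selection_op) -> None:
    """Write the initial best-metric JSON only once: a restart keeps the stored value."""
    if best_file.exists():
        return  # a previous run already recorded a best value, leave it alone
    best_file.parent.mkdir(parents=True, exist_ok=True)
    initial_value = np.inf if selection_op is operator.lt else -np.inf
    best_file.write_text(json.dumps({metric_name: str(initial_value), "step": 0}))
```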
3 changes: 2 additions & 1 deletion src/ashpy/restorers/__init__.py
@@ -45,11 +45,12 @@

 from ashpy.restorers.classifier import ClassifierRestorer
 from ashpy.restorers.gan import AdversarialEncoderRestorer, AdversarialRestorer
-from ashpy.restorers.restorer import Restorer
+from ashpy.restorers.restorer import ModelNotConstructedError, Restorer

 __ALL__ = [
     "Restorer",
     "AdversarialRestorer",
     "AdversarialEncoderRestorer",
     "ClassifierRestorer",
+    "ModelNotConstructedError",
 ]
40 changes: 39 additions & 1 deletion src/ashpy/restorers/restorer.py
@@ -21,7 +21,22 @@
 import ashpy
 import tensorflow as tf

-__ALL__ = ["Restorer"]
+__ALL__ = ["Restorer", "ModelNotConstructedError"]
+
+
+class ModelNotConstructedError(Exception):
+    """
+    Exception raised when restoring a sub-classed Model before it has been called on data.
+
+    Warning:
+        When restoring a :class:`tf.keras.Model` object from checkpoint, ensure that the
+        model has been correctly built and instantiated by first calling it on some
+        sample inputs. For a model built with either the Sequential or the Functional
+        API an exception will be raised; for a model built with the Chainer API it will
+        fail silently: restoration will be "successful", but no values will actually be
+        restored, since there are no valid placeholders as the model has not been built
+        yet.
+    """


 class Restorer:
@@ -95,6 +110,27 @@ def _validate_placeholder(placeholder: List, placeholder_type):
                 f"Object {placeholder} should be of type: {placeholder_type}"
             )

+    @staticmethod
+    def _check_model_construction(restored_model: tf.keras.Model) -> bool:
+        """
+        Optimistically check that the model.weights property returns a non-empty list.
+
+        The underlying assumption is that Models created via the sub-classing API, when
+        restored without being properly constructed (i.e., called on some input), will
+        have empty lists as layer weights.
+
+        TODO: add docs for the exception.
+        TODO: add a test case for a Sequential built without an input shape.
+        """
+        try:
+            if restored_model.weights == []:
+                raise ModelNotConstructedError
+        except AttributeError:
+            # A Sequential() built without specifying the input shape can be treated
+            # as a sub-classed model for restoration purposes.
+            raise ModelNotConstructedError
+        return True
+
     def restore_object(self, placeholder, object_ckpt_id: str):
         """
         Restore a placeholder from a checkpoint using the specified id.
@@ -112,6 +148,8 @@ def restore_object(self, placeholder, object_ckpt_id: str):
         """
         checkpoint = tf.train.Checkpoint(**{object_ckpt_id: placeholder})
         status = self._restore_checkpoint(checkpoint)
+        if isinstance(placeholder, tf.keras.Model):
+            assert self._check_model_construction(placeholder)
         print(self._restored_log_msg.format(object_ckpt_id, self._ckpts_dir))
         return status

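Together, the new exception and the construction check surface a subtle Keras behavior: a model defined via sub-classing (or a `Sequential` without an input shape) owns no variables until it is first called, so restoring a checkpoint into it silently attaches nothing. A hedged usage sketch: the checkpoint id `"model"` and the `logdir` contents are assumptions, not taken from this diff:

```python
import tensorflow as tf
from ashpy.restorers import ModelNotConstructedError, Restorer


class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)


model = SubclassedModel()
print(model.weights)  # [] -- not built yet: a restore would have nothing to fill

restorer = Restorer(logdir="testlog")  # assumes checkpoints exist under testlog
try:
    restorer.restore_object(model, "model")  # id must match the one used at save time
except ModelNotConstructedError:
    model(tf.zeros((1, 8)))                  # build the variables first...
    restorer.restore_object(model, "model")  # ...so restoration can assign real values
```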
21 changes: 16 additions & 5 deletions src/ashpy/trainers/classifier.py
@@ -15,7 +15,7 @@
 """Primitive Trainer Interface."""

 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import ashpy
 import tensorflow as tf
@@ -41,7 +41,7 @@ def __init__(
         optimizer: tf.optimizers.Optimizer,
         loss: ashpy.losses.ClassifierLoss,
         epochs: int,
-        metrics: Optional[List[Metric]] = None,
+        metrics: Optional[Union[Tuple[Metric], List[Metric]]] = None,
         callbacks: Optional[List[Callback]] = None,
         logdir: Union[Path, str] = Path().cwd() / "log",
         global_step: Optional[tf.Variable] = None,
@@ -56,7 +56,7 @@ def __init__(
             loss (:obj:`ashpy.losses.classifier.ClassifierLoss`): A loss function built following
                 :py:mod:`ashpy.executors`.
             epochs (int): Number of training epochs.
-            metrics: (List): List of :py:class:`ashpy.metrics.metric.Metric` to
+            metrics (Tuple/List): Tuple or List of :py:class:`ashpy.metrics.metric.Metric` to
                 measure on training and validation data.
             callbacks (List): List of :py:class:`ashpy.callbacks.callback.Callback` to
                 call on events.
@@ -136,9 +136,9 @@ def toy_dataset():

         self._avg_loss = ClassifierLoss(name="ashpy/avg_loss")
         if metrics:
-            metrics.append(self._avg_loss)
+            metrics = (*metrics, self._avg_loss)
         else:
-            metrics = [self._avg_loss]
+            metrics = (self._avg_loss,)

         super()._update_metrics(metrics)
         super()._validate_metrics()
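Replacing `metrics.append(...)` with tuple rebuilding removes a real side effect: appending mutated the caller's list, so a metrics collection shared across trainer instantiations (exactly what the parametrized test matrix does) grew an extra loss metric on every restart. The hazard in isolation:

```python
def with_list(metrics=None):
    if metrics:
        metrics.append("avg_loss")  # mutates the caller's list in place
    return metrics


def with_tuple(metrics=None):
    return (*metrics, "avg_loss") if metrics else ("avg_loss",)  # caller untouched


shared = ["accuracy"]
with_list(shared)
with_list(shared)
print(shared)  # ['accuracy', 'avg_loss', 'avg_loss'] -- grows on every "restart"

shared = ("accuracy",)
with_tuple(shared)
with_tuple(shared)
print(shared)  # ('accuracy',) -- unchanged
```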
@@ -160,6 +160,14 @@ def toy_dataset():
             checkpoint=self._checkpoint,
         )

+    def _build_and_restore_models(self, dataset: tf.data.Dataset):
+        restorer = ashpy.restorers.ClassifierRestorer(self._logdir)
+        (x, _) = next(iter(dataset.take(1)))
+        # Invoke model on sample input
+        self._model(x)
+        restorer.restore_model(self._model)
+        self._deferred_restoration = False
+
     def train_step(self, features, labels):
         """
         Train step.
@@ -212,6 +220,9 @@ def call(
                 performance.
         """
+        if self._deferred_restoration:
+            self._build_and_restore_models(dataset=training_set)
+
         # set the context properties
         self._context.training_set = training_set
         self._context.validation_set = validation_set
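With `_deferred_restoration` checked at the top of `call`, restarting a `ClassifierTrainer` in an existing `logdir` resumes instead of reinitializing: the model is first called on a sample batch (so its variables exist), then restored. An end-to-end sketch, hedged: the constructor arguments are inferred from this diff and the surrounding docstrings, and may differ in detail:

```python
import ashpy
import tensorflow as tf
from ashpy.trainers import ClassifierTrainer


def toy_data():
    x = tf.random.normal((64, 10))
    y = tf.random.uniform((64,), maxval=2, dtype=tf.int64)
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)


def make_trainer(epochs):
    # A fresh, un-built model each run: _build_and_restore_models calls it on a
    # sample batch before restoring, which avoids ModelNotConstructedError.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    return ClassifierTrainer(
        model=model,
        optimizer=tf.optimizers.Adam(1e-3),
        loss=ashpy.losses.ClassifierLoss(
            tf.losses.SparseCategoricalCrossentropy(from_logits=True)
        ),
        epochs=epochs,
        logdir="testlog",
    )


make_trainer(epochs=1)(toy_data(), toy_data())  # first run: train and checkpoint
make_trainer(epochs=2)(toy_data(), toy_data())  # "restart": restore, then continue
```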
38 changes: 33 additions & 5 deletions src/ashpy/trainers/gan.py
@@ -14,8 +14,9 @@

 """Collection of GANs trainers."""
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

+import ashpy.restorers
 import tensorflow as tf
 from ashpy.callbacks import Callback
 from ashpy.contexts.gan import GANContext, GANEncoderContext
@@ -138,7 +139,7 @@ def __init__(
         generator_loss: Executor,
         discriminator_loss: Executor,
         epochs: int,
-        metrics: Optional[List[Metric]] = None,
+        metrics: Optional[Union[Tuple[Metric], List[Metric]]] = None,
         callbacks: Optional[List[Callback]] = None,
         logdir: Union[Path, str] = Path().cwd() / "log",
         log_eval_mode: LogEvalMode = LogEvalMode.TEST,
@@ -191,12 +192,12 @@ def __init__(
         self._discriminator_loss = discriminator_loss
         self._discriminator_loss.reduction = tf.losses.Reduction.NONE

-        losses_metrics = [
+        losses_metrics = (
             DiscriminatorLoss(name="ashpy/d_loss", logdir=logdir),
             GeneratorLoss(name="ashpy/g_loss", logdir=logdir),
-        ]
+        )
         if metrics:
-            metrics.extend(losses_metrics)
+            metrics = (*metrics, *losses_metrics)
         else:
             metrics = losses_metrics
@@ -229,6 +230,16 @@ def __init__(
             metrics=self._metrics,
         )

+    def _build_and_restore_models(self, dataset: tf.data.Dataset):
+        restorer = ashpy.restorers.AdversarialRestorer(self._logdir)
+        (x, _), z = next(iter(dataset.take(1)))
+        # Invoke model on sample input
+        self._generator(z)
+        self._discriminator(x)
+        restorer.restore_generator(self._generator)
+        restorer.restore_discriminator(self._discriminator)
+        self._deferred_restoration = False
+
     def train_step(self, real_xy, g_inputs):
         """
         Train step for the AdversarialTrainer.
@@ -317,6 +328,9 @@ def call(
                 performance.
         """
+        if self._deferred_restoration:
+            self._build_and_restore_models(dataset=dataset)
+
         current_epoch = self._current_epoch()

         self._update_global_batch_size(
@@ -575,6 +589,17 @@ def __init__(
             metrics=self._metrics,
         )

+    def _build_and_restore_models(self, dataset: tf.data.Dataset):
+        restorer = ashpy.restorers.AdversarialEncoderRestorer(self._logdir)
+        (x, _), _ = next(iter(dataset.take(1)))
+
+        # Invoke model on sample input
+        self._encoder(x)
+        restorer.restore_encoder(self._encoder)
+
+        super()._build_and_restore_models(dataset)
+        self._deferred_restoration = False
+
     def train_step(self, real_xy, g_inputs):
         """Adversarial training step.
Expand Down Expand Up @@ -668,6 +693,9 @@ def call(
performance.
"""
if self._deferred_restoration:
self._build_and_restore_models(dataset=dataset)

current_epoch = self._current_epoch()

self._update_global_batch_size(
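The adversarial trainers make the same build-then-restore move and, in doing so, pin down the dataset element structure they expect: `((real_samples, labels), generator_noise)` per batch, with the encoder variant reading only the real samples. A hedged sketch of a dataset matching that spec (all shapes arbitrary):

```python
import tensorflow as tf

batch_size, latent_dim = 8, 100
real = tf.random.normal((batch_size, 64, 64, 3))
labels = tf.zeros((batch_size, 1))
noise = tf.random.normal((batch_size, latent_dim))

# One batch shaped the way AdversarialTrainer._build_and_restore_models unpacks it:
dataset = tf.data.Dataset.from_tensors(((real, labels), noise))

(x, _), z = next(iter(dataset.take(1)))  # x builds the discriminator, z the generator
```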
