Skip to content

Commit

Permalink
Merge pull request #36 from mr-ubik/30-model-restorer
Browse files Browse the repository at this point in the history
Add easy Checkpoint restoration
  • Loading branch information
galeone committed Jan 28, 2020
2 parents 2dc9aac + 66ed7a9 commit 56f7a26
Show file tree
Hide file tree
Showing 26 changed files with 1,075 additions and 142 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ good-names=_,
ex,
Ex,
f,
fn
fn,
fp,
G,
g,
Expand Down
1 change: 1 addition & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ formats: all
python:
version: 3.7
install:
- requirements: requirements/base.txt
- requirements: requirements/docs.txt
system_packages: true
66 changes: 66 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,25 @@
# limitations under the License.

"""pytest configuration."""

import operator
import os
import shutil

import pytest
import tensorflow # pylint: disable=import-error

import ashpy
from ashpy.metrics import (
ClassifierLoss,
InceptionScore,
SlicedWassersteinDistance,
SSIM_Multiscale,
)
from tests.utils.fake_training_loop import (
fake_adversarial_training_loop,
fake_classifier_training_loop,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -50,3 +62,57 @@ def save_dir():
if os.path.exists(m_save_dir):
shutil.rmtree(m_save_dir)
assert not os.path.exists(m_save_dir)


# ------------------------------------------------------------------------------------

TEST_MATRIX = {
# NOTE: Always pass metrics as Tuple, Trainers produce side effects!
"adversarial_trainer": [
fake_adversarial_training_loop,
{
"image_resolution": [256, 256],
"layer_spec_input_res": (8, 8),
"layer_spec_target_res": (8, 8),
"channels": 3,
"output_shape": 1,
"measure_performance_freq": 1,
"callbacks": [
ashpy.callbacks.LogImageGANCallback(
event=ashpy.callbacks.Event.ON_BATCH_END, event_freq=1
)
],
},
(
SlicedWassersteinDistance(resolution=256),
SSIM_Multiscale(),
InceptionScore(
# Fake inception model
ashpy.models.gans.ConvDiscriminator(
layer_spec_input_res=(299, 299),
layer_spec_target_res=(7, 7),
kernel_size=(5, 5),
initial_filters=16,
filters_cap=32,
output_shape=10,
)
),
),
],
"classifier_trainer": [
fake_classifier_training_loop,
{"measure_performance_freq": 1},
(ClassifierLoss(model_selection_operator=operator.lt),),
],
}

TRAINING_IDS = [k for k in TEST_MATRIX]
LOOPS = [TEST_MATRIX[k] for k in TEST_MATRIX]


@pytest.fixture(scope="function", params=LOOPS, ids=TRAINING_IDS)
def fake_training(request):
"""Fixture used to generate fake training for the tests."""
training_loop, loop_args, metrics = request.param
assert len(metrics) in [1, 3]
return (training_loop, loop_args, list(metrics))
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ API Reference
ashpy.metrics
ashpy.models
ashpy.modes
ashpy.restorers
ashpy.trainers

12 changes: 11 additions & 1 deletion docs/source/dependencies_graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Convolutional
GANs
====

GANs models are just aliases.

.. inheritance-diagram:: ashpy.models.gans
:parts: 1

Expand All @@ -42,7 +44,7 @@ ashpy.trainers
Adversarial
===========

.. inheritance-diagram:: ashpy.trainers.base_trainer ashpy.trainers.gan
.. inheritance-diagram:: ashpy.trainers.gan
:parts: 1

----
Expand All @@ -55,6 +57,14 @@ Classifier

----

ashpy.restorers
***************

.. inheritance-diagram:: ashpy.restorers.restorer ashpy.restorers.gan ashpy.restorers.classifier
:parts: 1

----

ashpy.layers
************

Expand Down
1 change: 1 addition & 0 deletions src/ashpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
losses,
metrics,
models,
restorers,
trainers,
)

Expand Down
19 changes: 19 additions & 0 deletions src/ashpy/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ class Callback(tf.Module):
"""

def __init__(self, name: str) -> None:
"""
Initialize the Callback.
Args:
name (str): Callback name.
Warning:
When using multiple callbacks with the same trainer make sure they have
different ids.
"""
self._name = name

@property
def name(self):
"""Return the name of the callback."""
return self._name

def on_event(self, event: Event, context: Context) -> None:
"""
Handle the on_event event.
Expand Down
2 changes: 1 addition & 1 deletion src/ashpy/callbacks/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class LogClassifierCallback(CounterCallback):
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name="LogClassifierCallback",
name="log_classifier_callback",
event_freq: int = 1,
):
"""
Expand Down
7 changes: 4 additions & 3 deletions src/ashpy/callbacks/counter_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class CounterCallback(Callback):
you can just inherit from CounterCallback.
"""

def __init__(self, event: Event, fn: Callable, name: str, event_freq: int = 1):
def __init__(
self, event: Event, fn: Callable, name: str, event_freq: int = 1
) -> None:
"""
Initialize the CounterCallback.
Expand All @@ -51,8 +53,7 @@ def __init__(self, event: Event, fn: Callable, name: str, event_freq: int = 1):
ValueError: if `event_freq` is not valid.
"""
super().__init__()
self._name = name
super().__init__(name=name)
if not isinstance(event, Event):
raise TypeError("Use the Event enum!")
self._event = event
Expand Down
10 changes: 5 additions & 5 deletions src/ashpy/callbacks/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ class LogImageGANCallback(CounterCallback):
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name="LogImageGANCallback",
name: str = "log_image_gan_callback",
event_freq: int = 1,
):
) -> None:
"""
Initialize the LogImageCallbackGAN.
Expand Down Expand Up @@ -250,11 +250,11 @@ def real_gen():
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name="LogImageGANEncoderCallback",
name: str = "log_image_gan_encoder_callback",
event_freq: int = 1,
):
) -> None:
"""
Initialize the LogImageCallbackGAN.
Initialize the LogImageGANEncoderCallback.
Args:
event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
Expand Down
55 changes: 55 additions & 0 deletions src/ashpy/restorers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Restorers allow for easy restoration of tracked objects from :class:`tf.train.Checkpoint`.
.. currentmodule:: ashpy.restorers
.. rubric:: Classes
.. autosummary::
:nosignatures:
:toctree: restorers
Restorer
AdversarialRestorer
AdversarialEncoderRestorer
ClassifierRestorer
----
.. rubric:: Modules
.. autosummary::
:nosignatures:
:toctree: restorers
:template: autosummary/submodule.rst
restorer
classifier
gan
"""

from ashpy.restorers.classifier import ClassifierRestorer
from ashpy.restorers.gan import AdversarialEncoderRestorer, AdversarialRestorer
from ashpy.restorers.restorer import Restorer

__ALL__ = [
"Restorer",
"AdversarialRestorer",
"AdversarialEncoderRestorer",
"ClassifierRestorer",
]
65 changes: 65 additions & 0 deletions src/ashpy/restorers/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convenience :class:`Restorer` to be used with :mod:`ashpy.trainers.classifier` ."""

import tensorflow as tf
from ashpy.restorers.restorer import Restorer
from ashpy.trainers import ClassifierTrainer

__ALL__ = ["ClassifierRestorer"]


class ClassifierRestorer(Restorer):
"""Convenience :class:`Restorer` for ease of use with the :class:`ClassifierTrainer`."""

def restore_model(self, model: tf.keras.Model) -> tf.keras.Model:
"""
Restore the Classifier model.
Args:
model (:class:`tf.keras.Model`): The placeholder model in which values from the
checkpoint will be restored.
Returns:
Restored model.
Warning:
When restoring a :class:`tf.keras.Model` object from checkpoint assure that the
model has been correctly built and instantiated by firstly calling it on some
sample inputs. In the case of a model built with either the Sequential or
Functional API an exception will be raised; for a model built with the Chainer API
it will fail silently, restoration will be "successful" but no values will actually
be restored since there are no valid placeholder as the model has not be built yet.
"""
self.restore_object(model, ClassifierTrainer.ckpt_id_model)
return model

def restore_optimizer(
self, optimizer: tf.keras.optimizers.Optimizer
) -> tf.keras.optimizers.Optimizer:
"""
Restore the Optimizer used to train the Classifier model.
Args:
model (:class:`tf.keras.optimizers.Optimizer`): The placeholder Optimizer in
which values from the checkpoint will be restored.
Returns:
Restored optimizer.
"""
self.restore_object(optimizer, ClassifierTrainer.ckpt_id_optimizer)
return optimizer

0 comments on commit 56f7a26

Please sign in to comment.