Skip to content

Commit

Permalink
Merge pull request #25 from EmanueleGhelfi/callbacks
Browse files Browse the repository at this point in the history
Callbacks
  • Loading branch information
EmanueleGhelfi committed Oct 7, 2019
2 parents 19d0213 + 41ad5c0 commit c6d93aa
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 168 deletions.
2 changes: 1 addition & 1 deletion ashpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
trainers,
)

__version__ = "0.1.3"
__version__ = "0.2.0"
__url__ = "https://github.com/zurutech/ashpy"
__author__ = "Machine Learning Team @ Zuru Tech"
__email__ = "ml@zuru.tech"
4 changes: 3 additions & 1 deletion ashpy/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from ashpy.callbacks.events import Event
from ashpy.contexts import Context

__ALL__ = ["Callback"]


class Callback(tf.Module):
r"""
Expand All @@ -26,7 +28,7 @@ class Callback(tf.Module):
Every callback must extend from this class.
This class defines the basic events.
Every event takes as input the context in order to use the objects defined.
Inheritance from tf.Module is required since callbacks have a state
Inheritance from :py:class:`tf.Module` is required since callbacks have a state
Order:
.. code-block::
Expand Down
12 changes: 8 additions & 4 deletions ashpy/callbacks/counter_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from ashpy.callbacks.events import Event
from ashpy.contexts import Context

__ALL__ = ["CounterCallback"]


class CounterCallback(Callback):
"""
Expand All @@ -52,6 +54,8 @@ def __init__(self, event: Event, fn: Callable, name: str, event_freq: int = 1):
"""
super().__init__()
self._name = name
if not isinstance(event, Event):
raise TypeError("Use the Event enum!")
self._event = event

if event_freq <= 0:
Expand All @@ -74,13 +78,13 @@ def on_event(self, event: Event, context: Context):
context (:py:class:`ashpy.contexts.context.Context`): current context.
"""
# check the event type
# Check the event type
if event == self._event:

# increment event counter
# Increment event counter
self._event_counter.assign_add(1)

# if the module between the event counter and the
# frequency is zero, call the fn
# If the module between the event counter and the
# Frequency is zero, call the fn
if tf.equal(tf.math.mod(self._event_counter, self._event_freq), 0):
self._fn(context)
4 changes: 2 additions & 2 deletions ashpy/callbacks/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def _log_fn(self, context: GANEncoderContext):
else:
raise ValueError("Invalid LogEvalMode")

# tensorboard 2.0 does not support float images in [-1, 1]
# only in [0,1]
# Tensorboard 2.0 does not support float images in [-1, 1]
# Only in [0,1]
if generator_of_encoder.dtype == tf.float32:
# The hypothesis is that image are in [-1,1] how to check?
generator_of_encoder = (generator_of_encoder + 1.0) / 2
Expand Down
26 changes: 17 additions & 9 deletions ashpy/callbacks/save_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,15 @@ def __init__(
self._models = models
self._verbose = verbose
self._max_to_keep = max_to_keep

if not isinstance(save_format, SaveFormat):
raise TypeError("Use the SaveFormat enum!")

self._save_format = save_format

if not isinstance(save_sub_format, SaveSubFormat):
raise TypeError("Use the SaveSubFormat enum!")

self._save_sub_format = save_sub_format
self._counter = 0
self._save_path_histories = [deque() for _ in self._models]
Expand Down Expand Up @@ -211,16 +219,16 @@ def _cleanup(self):
while self._counter > self._max_to_keep:
for save_path_history in self._save_path_histories:
if len(save_path_history) >= self._max_to_keep:
# get the first element of the queue
# Get the first element of the queue
save_dir_to_remove = save_path_history.popleft()

if self._verbose:
print(f"{self._name}: Removing {save_dir_to_remove} from disk.")

# remove directory
# Remove directory
shutil.rmtree(save_dir_to_remove, ignore_errors=True)

# decrease counter
# Decrease counter
self._counter -= 1

def _save_weights_fn(self, step: int):
Expand All @@ -240,27 +248,27 @@ def _save_weights_fn(self, step: int):
f"and sub-format {self._save_sub_format.value}."
)

# create the correct directory name
# Create the correct directory name
save_dir_i = os.path.join(self._save_dir, f"model-{i}-step-{step}")

if not os.path.exists(save_dir_i):
os.makedirs(save_dir_i)

# add to the history
# Add to the history
self._save_path_histories[i].append(save_dir_i)

# save using the save_format
# Save using the save_format
self._save_format.save(
model=model, save_dir=save_dir_i, save_sub_format=self._save_sub_format
)

# increase the counter of saved files
# Increase the counter of saved files
self._counter += 1

def save_weights_fn(self, context):
"""Save weights and clean up if needed."""
# save weights phase
# Save weights phase
self._save_weights_fn(context.global_step.numpy())

# clean up phase
# Clean up phase
self._cleanup()
8 changes: 4 additions & 4 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import shutil

import pytest
import tensorflow
import tensorflow # pylint: disable=import-error

import ashpy

Expand All @@ -39,14 +39,14 @@ def adversarial_logdir():
"""Add the logdir parameter to tests."""
m_adversarial_logdir = "testlog/adversarial"

# clean before
# Clean before
if os.path.exists(m_adversarial_logdir):
shutil.rmtree(m_adversarial_logdir)
assert not os.path.exists(m_adversarial_logdir)

yield m_adversarial_logdir

# teardown
# Teardown
if os.path.exists(m_adversarial_logdir):
shutil.rmtree(m_adversarial_logdir)
assert not os.path.exists(m_adversarial_logdir)
Expand All @@ -57,7 +57,7 @@ def save_dir():
"""Add the save_dir parameter to tests."""
m_save_dir = "testlog/savedir"

# clean before
# Clean before
if os.path.exists(m_save_dir):
shutil.rmtree(m_save_dir)
assert not os.path.exists(m_save_dir)
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tensorflow==2.0.0beta1
tensorflow
sphinx
sphinx-autobuild
sphinx-rtd-theme
Expand Down
50 changes: 25 additions & 25 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
# This file is autogenerated by pip-compile
# To update, run:
#
# pip-compile requirements.in
# pip-compile --pre requirements.in
#
absl-py==0.7.1 # via tb-nightly, tensorflow
absl-py==0.8.0 # via tb-nightly, tensorflow
alabaster==0.7.12 # via sphinx
argh==0.26.2 # via sphinx-autobuild, watchdog
astor==0.8.0 # via tensorflow
babel==2.7.0 # via sphinx
certifi==2019.6.16 # via requests
certifi==2019.9.11 # via requests
chardet==3.0.4 # via doc8, requests
doc8==0.8.0
docutils==0.14 # via doc8, m2r, restructuredtext-lint, sphinx
gast==0.2.2 # via tensorflow
docutils==0.15.2 # via doc8, m2r, restructuredtext-lint, sphinx
gast==0.3.2 # via tensorflow
google-pasta==0.1.7 # via tensorflow
grpcio==1.22.0 # via tb-nightly, tensorflow
h5py==2.9.0 # via keras-applications
grpcio==1.24.0 # via tb-nightly, tensorflow
h5py==2.10.0 # via keras-applications
idna==2.8 # via requests
imagesize==1.1.0 # via sphinx
jinja2==2.10.1 # via sphinx
Expand All @@ -27,44 +27,44 @@ m2r==0.2.1
markdown==3.1.1 # via tb-nightly
markupsafe==1.1.1 # via jinja2
mistune==0.8.4 # via m2r
numpy==1.17.0rc1 # via h5py, keras-applications, keras-preprocessing, tb-nightly, tensorflow, tensorflow-hub
packaging==19.0 # via sphinx
numpy==1.17.2 # via h5py, keras-applications, keras-preprocessing, tb-nightly, tensorflow, tensorflow-hub
packaging==19.2 # via sphinx
pathtools==0.1.2 # via sphinx-autobuild, watchdog
pbr==5.4.0 # via stevedore
pbr==5.4.3 # via stevedore
port_for==0.3.1 # via sphinx-autobuild
protobuf==3.9.0 # via tb-nightly, tensorflow, tensorflow-hub
pydocstyle==4.0.0
protobuf==3.10.0rc1 # via tb-nightly, tensorflow, tensorflow-hub
pydocstyle==4.0.1
pygments==2.4.2 # via sphinx
pyparsing==2.4.0 # via packaging
pytz==2019.1 # via babel
pyyaml==5.1.1 # via sphinx-autobuild, watchdog
pyparsing==2.4.2 # via packaging
pytz==2019.2 # via babel
pyyaml==5.1.2 # via sphinx-autobuild, watchdog
requests==2.22.0 # via sphinx
restructuredtext-lint==1.3.0 # via doc8
six==1.12.0 # via absl-py, doc8, grpcio, h5py, keras-preprocessing, livereload, packaging, protobuf, stevedore, tb-nightly, tensorflow, tensorflow-hub
snowballstemmer==1.9.0 # via pydocstyle, sphinx
snowballstemmer==1.9.1 # via pydocstyle, sphinx
sphinx-autobuild==0.7.1
sphinx-autodoc-typehints==1.6.0
sphinx-autodoc-typehints==1.8.0
sphinx-rtd-theme==0.4.3
sphinx==2.1.2
sphinx==2.2.0
sphinxcontrib-applehelp==1.0.1 # via sphinx
sphinxcontrib-devhelp==1.0.1 # via sphinx
sphinxcontrib-htmlhelp==1.0.2 # via sphinx
sphinxcontrib-jsmath==1.0.1 # via sphinx
sphinxcontrib-qthelp==1.0.2 # via sphinx
sphinxcontrib-serializinghtml==1.1.3 # via sphinx
sphinxcontrib-websupport==1.1.2
stevedore==1.30.1 # via doc8
stevedore==1.31.0 # via doc8
tb-nightly==1.14.0a20190603 # via tensorflow
tensorflow-hub==0.5.0
tensorflow==2.0.0beta1
tensorflow-hub==0.6.0
tensorflow==2.0.0b1
termcolor==1.1.0 # via tensorflow
tf-estimator-nightly==1.14.0.dev2019060501 # via tensorflow
tornado==6.0.3 # via livereload, sphinx-autobuild
urllib3==1.25.3 # via requests
urllib3==1.25.6 # via requests
watchdog==0.9.0 # via sphinx-autobuild
werkzeug==0.15.4 # via tb-nightly
wheel==0.33.4 # via tb-nightly, tensorflow
werkzeug==0.16.0 # via tb-nightly
wheel==0.33.6 # via tb-nightly, tensorflow
wrapt==1.11.2 # via tensorflow

# The following packages are considered to be unsafe in a requirements file:
# setuptools==41.0.1 # via markdown, protobuf, sphinx, tb-nightly
# setuptools==41.2.0 # via markdown, protobuf, sphinx, tb-nightly
60 changes: 33 additions & 27 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# limitations under the License.

# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

"""
Configuration file for the Sphinx documentation builder.
This file does only contain a selection of the most common options. For a
full list see the documentation:
http://www.sphinx-doc.org/en/master/config
"""

# pylint: disable=invalid-name

# -- Path setup --------------------------------------------------------------

Expand All @@ -30,12 +34,14 @@

sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))

import ashpy
import ashpy # pylint: disable=wrong-import-position

# -- Project information -----------------------------------------------------

project = "AshPy"
copyright = "2019 Zuru Tech HK Limited, All rights reserved."
copyright = ( # pylint: disable=redefined-builtin
"2019 Zuru Tech HK Limited, All rights reserved."
)
author = ashpy.__author__

# The short X.Y version
Expand Down Expand Up @@ -110,7 +116,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
Expand Down Expand Up @@ -165,20 +171,20 @@

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# latex_elements = {
# # The paper size ('letterpaper' or 'a4paper').
# #
# # 'papersize': 'letterpaper',
# # The font size ('10pt', '11pt' or '12pt').
# #
# # 'pointsize': '10pt',
# # Additional stuff for the LaTeX preamble.
# #
# # 'preamble': '',
# # Latex figure (float) alignment
# #
# # 'figure_align': 'htbp',
# }

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
Expand Down Expand Up @@ -237,13 +243,13 @@

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/", None),
"flask": ("http://flask.pocoo.org/docs/1.0/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"python": ("https://docs.python.org/", None),
"tensorflow": (
"https://www.tensorflow.org/versions/r2.0/api_docs/python/",
"tf2_py_objects.inv",
"https://www.tensorflow.org/api_docs/python",
"https://github.com/mr-ubik/tensorflow-intersphinx/raw/master/tf2_py_objects.inv",
),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
}

# -- Options for todo extension ----------------------------------------------
Expand Down

0 comments on commit c6d93aa

Please sign in to comment.