Skip to content

Commit

Permalink
Merge pull request #33 from mr-ubik/issue-29-log-folder-bug
Browse files Browse the repository at this point in the history
Issue 29 log folder bug
  • Loading branch information
galeone committed Jan 22, 2020
2 parents 4196dbf + 491c66d commit b5d7215
Show file tree
Hide file tree
Showing 20 changed files with 287 additions and 128 deletions.
18 changes: 0 additions & 18 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,6 @@ def add_common_namespaces(doctest_namespace):
doctest_namespace["callbacks"] = ashpy.callbacks


@pytest.fixture(scope="function")
def adversarial_logdir():
"""Add the logdir parameter to tests."""
m_adversarial_logdir = "testlog/adversarial"

# Clean before
if os.path.exists(m_adversarial_logdir):
shutil.rmtree(m_adversarial_logdir)
assert not os.path.exists(m_adversarial_logdir)

yield m_adversarial_logdir

# Teardown
if os.path.exists(m_adversarial_logdir):
shutil.rmtree(m_adversarial_logdir)
assert not os.path.exists(m_adversarial_logdir)


@pytest.fixture(scope="function")
def save_dir():
"""Add the save_dir parameter to tests."""
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"sphinx.ext.graphviz",
"m2r",
"sphinx_autodoc_typehints",
"sphinx_copybutton",
]

# Autodoc
Expand Down
1 change: 1 addition & 0 deletions requirements.in/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ pydocstyle
sphinx
sphinx-autobuild
sphinx-autodoc-typehints
sphinx-copybutton
sphinx-rtd-theme
10 changes: 7 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ click==7.0 # via black
codecov==2.0.15
coverage==5.0.3 # via codecov, pytest-cov
doc8==0.8.0
docutils==0.16 # via doc8, restructuredtext-lint, sphinx
docutils==0.16 # via doc8, flit, restructuredtext-lint, sphinx
entrypoints==0.3 # via flake8
filelock==3.0.12 # via tox
flake8-bugbear==20.1.2
flake8==3.7.9
flit-core==2.2.0 # via flit
flit==2.2.0 # via sphinx-copybutton
gast==0.2.2 # via tensorflow
google-auth-oauthlib==0.4.1 # via tensorboard
google-auth==1.10.1 # via google-auth-oauthlib, tensorboard
Expand Down Expand Up @@ -68,11 +70,12 @@ pylint==2.4.4
pyparsing==2.4.6 # via packaging
pytest-cov==2.8.1
pytest==5.3.3
pytoml==0.1.21 # via flit, flit-core
pytz==2019.3 # via babel
pyyaml==5.3 # via sphinx-autobuild, watchdog
regex==2020.1.8 # via black
requests-oauthlib==1.3.0 # via google-auth-oauthlib
requests==2.22.0 # via codecov, requests-oauthlib, sphinx, tensorboard
requests==2.22.0 # via codecov, flit, requests-oauthlib, sphinx, tensorboard
restructuredtext-lint==1.3.0 # via doc8
rope==0.16.0
rsa==4.0 # via google-auth
Expand All @@ -81,6 +84,7 @@ six==1.14.0 # via absl-py, astroid, doc8, google-auth, google-past
snowballstemmer==2.0.0 # via pydocstyle, sphinx
sphinx-autobuild==0.7.1
sphinx-autodoc-typehints==1.10.3
sphinx-copybutton==0.2.8
sphinx-rtd-theme==0.4.3
sphinx==2.3.1
sphinxcontrib-applehelp==1.0.1 # via sphinx
Expand All @@ -105,7 +109,7 @@ virtualenv==16.7.9 # via tox
watchdog==0.9.0 # via sphinx-autobuild
wcwidth==0.1.8 # via pytest
werkzeug==0.16.0 # via tensorboard
wheel==0.33.6 # via tensorboard, tensorflow
wheel==0.33.6 # via sphinx-copybutton, tensorboard, tensorflow
wrapt==1.11.2 # via astroid, tensorflow
zipp==2.0.0 # via importlib-metadata

Expand Down
10 changes: 7 additions & 3 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ click==7.0 # via black
codecov==2.0.15
coverage==5.0.3 # via codecov, pytest-cov
doc8==0.8.0
docutils==0.16 # via doc8, restructuredtext-lint, sphinx
docutils==0.16 # via doc8, flit, restructuredtext-lint, sphinx
entrypoints==0.3 # via flake8
filelock==3.0.12 # via tox
flake8-bugbear==20.1.2
flake8==3.7.9
flit-core==2.2.0 # via flit
flit==2.2.0 # via sphinx-copybutton
gast==0.2.2 # via tensorflow
google-auth-oauthlib==0.4.1 # via tensorboard
google-auth==1.10.1 # via google-auth-oauthlib, tensorboard
Expand Down Expand Up @@ -68,11 +70,12 @@ pylint==2.4.4
pyparsing==2.4.6 # via packaging
pytest-cov==2.8.1
pytest==5.3.3
pytoml==0.1.21 # via flit, flit-core
pytz==2019.3 # via babel
pyyaml==5.3 # via sphinx-autobuild, watchdog
regex==2020.1.8 # via black
requests-oauthlib==1.3.0 # via google-auth-oauthlib
requests==2.22.0 # via codecov, requests-oauthlib, sphinx, tensorboard
requests==2.22.0 # via codecov, flit, requests-oauthlib, sphinx, tensorboard
restructuredtext-lint==1.3.0 # via doc8
rope==0.16.0
rsa==4.0 # via google-auth
Expand All @@ -81,6 +84,7 @@ six==1.14.0 # via absl-py, astroid, doc8, google-auth, google-past
snowballstemmer==2.0.0 # via pydocstyle, sphinx
sphinx-autobuild==0.7.1
sphinx-autodoc-typehints==1.10.3
sphinx-copybutton==0.2.8
sphinx-rtd-theme==0.4.3
sphinx==2.3.1
sphinxcontrib-applehelp==1.0.1 # via sphinx
Expand All @@ -105,7 +109,7 @@ virtualenv==16.7.9 # via tox
watchdog==0.9.0 # via sphinx-autobuild
wcwidth==0.1.8 # via pytest
werkzeug==0.16.0 # via tensorboard
wheel==0.33.6 # via tensorboard, tensorflow
wheel==0.33.6 # via sphinx-copybutton, tensorboard, tensorflow
wrapt==1.11.2 # via astroid, tensorflow
zipp==2.0.0 # via importlib-metadata

Expand Down
9 changes: 7 additions & 2 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ babel==2.8.0 # via sphinx
certifi==2019.11.28 # via requests
chardet==3.0.4 # via doc8, requests
doc8==0.8.0
docutils==0.16 # via doc8, restructuredtext-lint, sphinx
docutils==0.16 # via doc8, flit, restructuredtext-lint, sphinx
flit-core==2.2.0 # via flit
flit==2.2.0 # via sphinx-copybutton
idna==2.8 # via requests
imagesize==1.2.0 # via sphinx
jinja2==2.10.3 # via sphinx
Expand All @@ -23,14 +25,16 @@ port_for==0.3.1 # via sphinx-autobuild
pydocstyle==5.0.2
pygments==2.5.2 # via sphinx
pyparsing==2.4.6 # via packaging
pytoml==0.1.21 # via flit, flit-core
pytz==2019.3 # via babel
pyyaml==5.3 # via sphinx-autobuild, watchdog
requests==2.22.0 # via sphinx
requests==2.22.0 # via flit, sphinx
restructuredtext-lint==1.3.0 # via doc8
six==1.14.0 # via doc8, livereload, packaging, stevedore
snowballstemmer==2.0.0 # via pydocstyle, sphinx
sphinx-autobuild==0.7.1
sphinx-autodoc-typehints==1.10.3
sphinx-copybutton==0.2.8
sphinx-rtd-theme==0.4.3
sphinx==2.3.1
sphinxcontrib-applehelp==1.0.1 # via sphinx
Expand All @@ -43,6 +47,7 @@ stevedore==1.31.0 # via doc8
tornado==6.0.3 # via livereload, sphinx-autobuild
urllib3==1.25.7 # via requests
watchdog==0.9.0 # via sphinx-autobuild
wheel==0.33.6 # via sphinx-copybutton

# The following packages are considered to be unsafe in a requirements file:
# setuptools
1 change: 0 additions & 1 deletion src/ashpy/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
self._metric = metric
self._model_selection_operator = model_selection_operator
self._logdir = logdir
self._update_logdir()

def model_selection(
self, checkpoint: tf.train.Checkpoint, global_step: tf.Variable
Expand Down
4 changes: 1 addition & 3 deletions src/ashpy/trainers/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def toy_dataset():
else:
metrics = [self._avg_loss]

for metric in metrics:
metric.logdir = self._logdir
self._metrics = metrics
super()._update_metrics(metrics)

self._checkpoint.objects.extend([self._optimizer, self._model])
self._restore_or_init()
Expand Down
8 changes: 4 additions & 4 deletions src/ashpy/trainers/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class AdversarialTrainer(Trainer):
output_shape=10,
),
model_selection_operator=operator.gt,
logdir=logdir,
)
]
trainer = trainers.gan.AdversarialTrainer(
Expand Down Expand Up @@ -194,7 +193,7 @@ def __init__(
else:
metrics = losses_metrics

self._metrics = metrics
super()._update_metrics(metrics)

self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
Expand Down Expand Up @@ -430,7 +429,6 @@ def real_gen():
metrics.gan.EncodingAccuracy(
classifier,
# model_selection_operator=operator.gt,
logdir=logdir
)
]
Expand Down Expand Up @@ -526,6 +524,9 @@ def __init__(
track of the training steps.
"""
if not metrics:
metrics = []
metrics.append(EncoderLoss(logdir=logdir))
super().__init__(
generator=generator,
discriminator=discriminator,
Expand All @@ -547,7 +548,6 @@ def __init__(
self._encoder_loss = encoder_loss
self._encoder_loss.reduction = tf.losses.Reduction.NONE

self._metrics.append(EncoderLoss(logdir=logdir))
self._checkpoint.objects.extend([self._encoder, self._encoder_optimizer])
self._restore_or_init()

Expand Down
6 changes: 6 additions & 0 deletions src/ashpy/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def _validate_callbacks(self):
"""Check if every callback is an :py:class:`ashpy.callbacks.Callback`."""
validate_objects(self._callbacks, Callback)

def _update_metrics(self, metrics):
if metrics:
for metric in metrics:
metric.logdir = self._logdir
self._metrics = metrics

def _update_global_batch_size(
self,
dataset: tf.data.Dataset,
Expand Down
8 changes: 3 additions & 5 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
from ashpy.callbacks.events import Event
from ashpy.callbacks.gan import LogImageGANCallback

from tests.utils.fake_training_loop import fake_training_loop
from tests.utils.fake_training_loop import fake_adversarial_training_loop


def test_callbacks(adversarial_logdir: str):
def test_callbacks(tmpdir):
"""Test the integration between callbacks and trainer."""

callbacks = [LogImageGANCallback(event=Event.ON_BATCH_END, event_freq=1)]

fake_training_loop(adversarial_logdir, callbacks=callbacks)
fake_adversarial_training_loop(tmpdir, callbacks=callbacks)
8 changes: 4 additions & 4 deletions tests/callbacks/test_counter_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ashpy.callbacks import CounterCallback, Event
from ashpy.models.gans import ConvDiscriminator, ConvGenerator

from tests.utils.fake_training_loop import fake_training_loop
from tests.utils.fake_training_loop import fake_adversarial_training_loop


class FakeCounterCallback(CounterCallback):
Expand Down Expand Up @@ -74,16 +74,16 @@ def test_counter_callback_multiple_events():


# TODO: parametrize tests following test_save_callback.py
def test_counter_callback(_models, adversarial_logdir):
def test_counter_callback(_models, tmpdir):
clbk = FakeCounterCallback(
event=Event.ON_EPOCH_END,
name="TestCounterCallback",
fn=lambda context: print("Bloop"),
)
callbacks = [clbk]
generator, discriminator = _models
fake_training_loop(
adversarial_logdir,
fake_adversarial_training_loop(
logdir=tmpdir,
callbacks=callbacks,
generator=generator,
discriminator=discriminator,
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_custom_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ashpy.callbacks import Callback
from ashpy.callbacks.events import Event

from tests.utils.fake_training_loop import fake_training_loop
from tests.utils.fake_training_loop import fake_adversarial_training_loop


class MCallback(Callback):
Expand Down Expand Up @@ -55,7 +55,7 @@ def get_n_events_from_epochs(


@pytest.mark.parametrize("event", list(Event))
def test_custom_callbacks(adversarial_logdir: str, event: Event):
def test_custom_callbacks(tmpdir, event: Event):
"""Test the integration between a custom callback and a trainer."""
m_callback = MCallback(event)
callbacks = [m_callback]
Expand All @@ -64,8 +64,8 @@ def test_custom_callbacks(adversarial_logdir: str, event: Event):
dataset_size = 2
batch_size = 2

fake_training_loop(
adversarial_logdir,
fake_adversarial_training_loop(
logdir=tmpdir,
callbacks=callbacks,
epochs=epochs,
dataset_size=dataset_size,
Expand Down
29 changes: 8 additions & 21 deletions tests/callbacks/test_save_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ashpy.callbacks import SaveCallback, SaveFormat, SaveSubFormat
from ashpy.models.gans import ConvDiscriminator, ConvGenerator

from tests.utils.fake_training_loop import fake_training_loop
from tests.utils.fake_training_loop import fake_adversarial_training_loop

COMPATIBLE_FORMAT_AND_SUB_FORMAT = [
(SaveFormat.WEIGHTS, SaveSubFormat.TF),
Expand All @@ -33,15 +33,11 @@

@pytest.mark.parametrize("save_format_and_sub_format", COMPATIBLE_FORMAT_AND_SUB_FORMAT)
def test_save_callback_compatible(
adversarial_logdir: str,
save_format_and_sub_format: Tuple[SaveFormat, SaveSubFormat],
save_dir: str,
tmpdir, save_format_and_sub_format: Tuple[SaveFormat, SaveSubFormat], save_dir: str,
):
"""Test the integration between callbacks and trainer."""
save_format, save_sub_format = save_format_and_sub_format
_test_save_callback_helper(
adversarial_logdir, save_format, save_sub_format, save_dir
)
_test_save_callback_helper(tmpdir, save_format, save_sub_format, save_dir)

save_dirs = os.listdir(save_dir)
# 2 folders: generator and discriminator
Expand All @@ -58,25 +54,19 @@ def test_save_callback_compatible(
"save_format_and_sub_format", INCOMPATIBLE_FORMAT_AND_SUB_FORMAT
)
def test_save_callback_incompatible(
adversarial_logdir: str,
save_format_and_sub_format: Tuple[SaveFormat, SaveSubFormat],
save_dir: str,
tmpdir, save_format_and_sub_format: Tuple[SaveFormat, SaveSubFormat], save_dir: str,
):
"""Test the integration between callbacks and trainer."""
save_format, save_sub_format = save_format_and_sub_format

with pytest.raises(NotImplementedError):
_test_save_callback_helper(
adversarial_logdir, save_format, save_sub_format, save_dir
)
_test_save_callback_helper(tmpdir, save_format, save_sub_format, save_dir)

# assert no folder has been created
assert not os.path.exists(save_dir)


def _test_save_callback_helper(
adversarial_logdir, save_format, save_sub_format, save_dir
):
def _test_save_callback_helper(tmpdir, save_format, save_sub_format, save_dir):
image_resolution = (28, 28)
layer_spec_input_res = (7, 7)
layer_spec_target_res = (7, 7)
Expand Down Expand Up @@ -112,11 +102,8 @@ def _test_save_callback_helper(
)
]

fake_training_loop(
adversarial_logdir,
callbacks=callbacks,
generator=generator,
discriminator=discriminator,
fake_adversarial_training_loop(
tmpdir, callbacks=callbacks, generator=generator, discriminator=discriminator,
)


Expand Down

0 comments on commit b5d7215

Please sign in to comment.