BUGFIX: SWD children metrics now have their logdir updated from their parent
- Make CodeFactor happy
- Improve test suite
mr-ubik committed Jan 23, 2020
1 parent 5429172 commit 1e2c79d
Showing 6 changed files with 73 additions and 78 deletions.
22 changes: 16 additions & 6 deletions src/ashpy/metrics/metric.py
@@ -82,19 +82,19 @@ def model_selection(
"""
current_value = self.result()
previous_value = float(
-            self.json_read(self.best_model_sel_file)[self._name.replace("/", "_")]
+            self.json_read(self.best_model_sel_file)[self.sanitized_name]
)
# Model selection is done ONLY if an operator was passed at __init__
if self._model_selection_operator and self._model_selection_operator(
current_value, previous_value
):
tf.print(
f"{self._name.replace('/', '_')}: validation value: {previous_value}{current_value}"
f"{self.sanitized_name}: validation value: {previous_value}{current_value}"
)
Metric.json_write(
self.best_model_sel_file,
{
self._name.replace("/", "_"): str(current_value),
self.sanitized_name: str(current_value),
"step": int(global_step.numpy()),
},
)
@@ -114,14 +114,24 @@ def _update_logdir(self):
)
self.json_write(
self.best_model_sel_file,
-            {self._name.replace("/", "_"): str(initial_value), "step": 0},
+            {self.sanitized_name: str(initial_value), "step": 0},
)

@property
def name(self) -> str:
"""Retrieve the metric name."""
return self._name

+    @property
+    def sanitized_name(self) -> str:
+        """
+        Retrieve the sanitized name: every "/" is replaced with "_".
+
+        Prefixing a metric name with "/" enables TensorBoard's automatic
+        grouping; when not working with TensorBoard, every "/" is replaced
+        with "_".
+        """
+        return self._name.replace("/", "_")

@property
def metric(self) -> tf.keras.metrics.Metric:
"""Retrieve the :py:class:`tf.keras.metrics.Metric` object."""
@@ -146,12 +156,12 @@ def logdir(self, logdir) -> None:
@property
def best_folder(self) -> str:
"""Retrieve the folder used to save the best model when doing model selection."""
-        return os.path.join(self.logdir, "best", self._name.replace("/", "_"))
+        return os.path.join(self.logdir, "best", self.sanitized_name)

@property
def best_model_sel_file(self) -> str:
"""Retrieve the path to JSON file containing the measured performance of the best model."""
-        return os.path.join(self.best_folder, self._name.replace("/", "_") + ".json")
+        return os.path.join(self.best_folder, self.sanitized_name + ".json")

@staticmethod
def json_read(filename: str) -> Dict[str, Any]:
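For orientation, a minimal sketch of what the new property does, using a stand-in class rather than the real ashpy Metric:

# Stand-in class (not the actual ashpy API): sanitized_name turns
# TensorBoard-style slash-grouped names into filesystem-safe ones.
class DemoMetric:
    def __init__(self, name: str) -> None:
        self._name = name

    @property
    def sanitized_name(self) -> str:
        return self._name.replace("/", "_")

assert DemoMetric("SWD/rgb/64").sanitized_name == "SWD_rgb_64"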
13 changes: 13 additions & 0 deletions src/ashpy/metrics/sliced_wasserstein_metric.py
@@ -155,6 +155,19 @@ def __init__(
)
]

+    @property
+    def logdir(self) -> str:
+        """Retrieve the log directory."""
+        return self._logdir
+
+    @logdir.setter
+    def logdir(self, logdir) -> None:
+        """Set the logdir, also propagating it to the children metrics."""
+        self._logdir = logdir
+        self._update_logdir()
+        for child_metric_real, child_metric_fake in self.children_real_fake:
+            child_metric_real.logdir, child_metric_fake.logdir = logdir, logdir

def update_state(self, context: GANContext) -> None:
"""
Update the internal state of the metric, using the information from the context object.
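The setter above is the actual bugfix: before it, a child metric created by SlicedWassersteinDistance kept the logdir it was built with even after the parent's logdir changed. A self-contained sketch with simplified stand-in classes (not the real ashpy metrics):

class ChildMetric:
    def __init__(self, logdir: str) -> None:
        self.logdir = logdir

class ParentMetric:
    def __init__(self, logdir: str) -> None:
        self._logdir = logdir
        # one (real, fake) metric pair per resolution, as in SWD
        self.children_real_fake = [
            (ChildMetric(logdir), ChildMetric(logdir)) for _ in range(3)
        ]

    @property
    def logdir(self) -> str:
        return self._logdir

    @logdir.setter
    def logdir(self, logdir: str) -> None:
        self._logdir = logdir
        for child_real, child_fake in self.children_real_fake:
            child_real.logdir, child_fake.logdir = logdir, logdir

parent = ParentMetric("old_log")
parent.logdir = "new_log"  # cascades to every child
assert all(
    real.logdir == fake.logdir == "new_log"
    for real, fake in parent.children_real_fake
)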
7 changes: 6 additions & 1 deletion src/ashpy/trainers/classifier.py
@@ -61,8 +61,9 @@ def __init__(
Examples:
.. testcode::
-import shutil
 import operator
+import shutil
+import pathlib
from ashpy.metrics import ClassifierMetric
from ashpy.trainers.classifier import ClassifierTrainer
from ashpy.losses.classifier import ClassifierLoss
@@ -84,6 +85,9 @@ def toy_dataset():
logdir = "testlog"
epochs = 2
+if pathlib.Path(logdir).exists():
+    shutil.rmtree(logdir)
metrics = [
ClassifierMetric(tf.metrics.Accuracy()),
ClassifierMetric(tf.metrics.BinaryAccuracy()),
@@ -97,6 +101,7 @@ def toy_dataset():
logdir=logdir)
train, validation = toy_dataset(), toy_dataset()
trainer(train, validation)
+shutil.rmtree(logdir)
.. testoutput::
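The guard added around shutil.rmtree keeps repeated doctest runs from failing on leftover state; as a standalone sketch (assuming a throwaway "testlog" directory):

import pathlib
import shutil

logdir = "testlog"
if pathlib.Path(logdir).exists():
    shutil.rmtree(logdir)  # remove leftovers from a previous run
pathlib.Path(logdir).mkdir()  # stand-in for the training side effects
shutil.rmtree(logdir)  # leave no trace after the example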
4 changes: 2 additions & 2 deletions src/ashpy/trainers/trainer.py
@@ -142,9 +142,9 @@ def _validate_metrics(self):
validate_objects(self._metrics, Metric)
buffer = []
for metric in self._metrics:
-            if metric._name in buffer:
+            if metric.sanitized_name in buffer:
                 raise ValueError("Metric should have unique names.")
-            buffer.append(metric._name)
+            buffer.append(metric.sanitized_name)

def _validate_callbacks(self):
"""Check if every callback is an :py:class:`ashpy.callbacks.Callback`."""
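Checking uniqueness on sanitized_name is slightly stricter than on the raw name: two distinct raw names can collide once "/" becomes "_", and such a pair would presumably also collide on the best-model folder path. A quick illustration:

names = ["loss/generator", "loss_generator"]  # distinct raw names
sanitized = {name.replace("/", "_") for name in names}
assert len(names) == 2 and len(sanitized) == 1  # collision after sanitizing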
90 changes: 27 additions & 63 deletions tests/test_metrics.py
@@ -15,7 +15,7 @@
"""
Test Metrics with the various trainers.
-TODO: Adversarial Encoder
+TODO: Adversarial Encoder Trainer
"""
import json
import operator
@@ -37,7 +37,7 @@

from tests.utils.fake_training_loop import (
fake_adversarial_training_loop,
-    fake_classifier_taining_loop,
+    fake_classifier_training_loop,
)

DEFAULT_LOGDIR = "log"
@@ -49,6 +49,7 @@
"layer_spec_input_res": (8, 8),
"layer_spec_target_res": (8, 8),
"channels": 3,
"measure_performance_freq": 1,
},
[
SlicedWassersteinDistance(resolution=256),
@@ -67,15 +67,17 @@
],
],
"classifier_trainer": [
-        fake_classifier_taining_loop,
-        {},
+        fake_classifier_training_loop,
+        {"measure_performance_freq": 1},
[ClassifierLoss(model_selection_operator=operator.lt)],
],
}

TEST_PARAMS_METRICS_LOG = [TEST_MATRIX_METRICS_LOG[k] for k in TEST_MATRIX_METRICS_LOG]
TEST_IDS_METRICS_LOG = [k for k in TEST_MATRIX_METRICS_LOG]

+OPERATOR_INITIAL_VALUE_MAP = {operator.gt: "-inf", operator.lt: "inf"}

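The map mirrors the sentinel that _update_logdir writes at step 0: under the matching operator, any real measurement should beat the initial value. A quick check of that property:

import operator

OPERATOR_INITIAL_VALUE_MAP = {operator.gt: "-inf", operator.lt: "inf"}
for op in (operator.gt, operator.lt):
    sentinel = float(OPERATOR_INITIAL_VALUE_MAP[op])
    assert op(0.123, sentinel)  # 0.123 > -inf and 0.123 < inf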

@pytest.fixture(scope="module")
def cleanup():
@@ -97,12 +100,15 @@ def test_metrics_log(training_loop, loop_args, metrics, tmpdir, cleanup):
"""
Test that trainers correctly create metrics log files.
Also test that model selection has been correctly performed.
GIVEN a correctly instantiated trainer
GIVEN some training has been done
THEN there should not be any logs inside the default log folder
THEN there exists a logdir folder for each metric
THEN inside each folder there's the JSON file w/ the metric logs
THEN in the file there are the correct keys
THEN the values of the keys should not be the operator initial value
"""
training_completed, trainer = training_loop(
@@ -113,77 +119,35 @@ def test_metrics_log(training_loop, loop_args, metrics, tmpdir, cleanup):
assert not pathlib.Path(DEFAULT_LOGDIR).exists() # Assert absence of side effects
# Assert there exists folder for each metric
for metric in trainer._metrics:
-        metric_dir = pathlib.Path(tmpdir).joinpath(
-            "best", metric.name.replace("/", "_")
-        )
+        metric_dir = pathlib.Path(tmpdir).joinpath("best", metric.sanitized_name)
assert metric_dir.exists()
-        json_path = metric_dir.joinpath(f"{metric.name.replace('/', '_')}.json")
+        json_path = metric_dir.joinpath(f"{metric.sanitized_name}.json")
assert json_path.exists()
with open(json_path, "r") as fp:
metric_data = json.load(fp)

# Assert the metric data contains the expected keys
-        assert metric.name.replace("/", "_") in metric_data
+        assert metric.sanitized_name in metric_data
assert "step" in metric_data


# -------------------------------------------------------------------------------------


-TEST_MATRIX_MODEL_SELECTION = {
-    # TODO: Add test for a metric with operator.gt
-    "classifier_trainer_lt": [
-        fake_classifier_taining_loop,
-        {},
-        [ClassifierLoss(model_selection_operator=operator.lt)],
-        operator.lt,
-    ],
-}
-TEST_PARAMS_MODEL_SELECTION = [
-    TEST_MATRIX_MODEL_SELECTION[k] for k in TEST_MATRIX_MODEL_SELECTION
-]
-TEST_IDS_MODEL_SELECTION = [k for k in TEST_MATRIX_MODEL_SELECTION]
-
-
-@pytest.mark.parametrize(
-    ["training_loop", "loop_args", "metrics", "operator_check"],
-    TEST_PARAMS_MODEL_SELECTION,
-    ids=TEST_IDS_MODEL_SELECTION,
-)
-def test_model_selection(
-    training_loop, loop_args, metrics, operator_check: List[Metric]
-):
-    """
-    Test the correct model selection behaviour of metrics.
-    Model selection is handled by the Metric when triggered by a trainer.
-    GIVEN a correctly instantiated trainer
-    GIVEN some training has been done
-    THEN there should be a metric log file containing two values
-    GIVEN all metrics log get initialized at -inf at step 0
-    GIVEN a new data point is added to the log
-    WHEN performing model selection this should be the value used
-    """
-    number_of_metrics = len(metrics)
-    for metric in metrics:
-        assert metric._model_selection_operator == operator_check
-
-    training_completed, trainer = training_loop(
-        logdir="testlog", metrics=metrics, **loop_args
-    )
-
-    # Maybe make it explicit when a trainer populates metrics autonomously
-    assert training_completed
-    # Manually have to check this in the log. Find a better way.
-    # loss: validation value: inf → 0.0
+        # Assert that the correct model selection has been performed
+        # Check it by seeing if the values in the JSON have been updated
+        if metric.model_selection_operator:
+            try:
+                initial_value = OPERATOR_INITIAL_VALUE_MAP[
+                    metric.model_selection_operator
+                ]
+            except KeyError:
+                raise ValueError(
+                    "Please add the initial value for this operator to OPERATOR_INITIAL_VALUE_MAP"
+                )
+            assert metric_data[metric.sanitized_name] != initial_value


# -------------------------------------------------------------------------------------


@pytest.mark.parametrize("training_loop", [fake_classifier_taining_loop])
@pytest.mark.parametrize("training_loop", [fake_classifier_training_loop])
def test_metrics_names_collision(training_loop, tmpdir):
"""
Test that an exception is correctly raised when two metrics have the same name.
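For orientation, the file the test reads sits at best/<sanitized_name>/<sanitized_name>.json; with illustrative (made-up) values it looks like this:

import json

metric_data = json.loads('{"ClassifierLoss": "0.125", "step": 10}')
assert "step" in metric_data
assert metric_data["ClassifierLoss"] != "inf"  # model selection happened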
15 changes: 9 additions & 6 deletions tests/utils/fake_training_loop.py
@@ -14,7 +14,6 @@

"""Fake training loop to simplify training in tests."""
import operator
-from pathlib import Path

import ashpy
import tensorflow as tf
@@ -26,9 +25,9 @@
from tests.utils.fake_models import conv_autoencoder


-def fake_classifier_taining_loop(
+def fake_classifier_training_loop(
     # Trainer
-    logdir: Path,
+    logdir: str = "testlog",
optimizer=tf.optimizers.Adam(1e-4),
metrics=[ashpy.metrics.ClassifierLoss(model_selection_operator=operator.lt)],
epochs=2,
@@ -44,6 +43,8 @@ def fake_classifier_taining_loop(
filters_cap=64,
encoding_dimension=50,
channels=3,
+    # Call parameters
+    measure_performance_freq=10,
):
"""Fake Classifier training loop implementation using an autoencoder as a base model."""
# Model
@@ -77,12 +78,12 @@ def fake_classifier_taining_loop(
metrics=metrics,
)

-    trainer(dataset, dataset)
+    trainer(dataset, dataset, measure_performance_freq=measure_performance_freq)
return 1, trainer


def fake_adversarial_training_loop(
-    logdir,
+    logdir: str = "testlog",
generator=None,
discriminator=None,
metrics=None,
@@ -96,6 +97,8 @@
layer_spec_input_res=(7, 7),
layer_spec_target_res=(7, 7),
channels=1,
+    # Call parameters
+    measure_performance_freq=10,
):
"""Fake training loop implementation."""
# test parameters
@@ -162,5 +165,5 @@ def fake_adversarial_training_loop(
lambda x, y: ((x, y), tf.random.normal(shape=(batch_size, latent_dim)))
)

-    trainer(dataset)
+    trainer(dataset, measure_performance_freq=measure_performance_freq)
return 1, trainer
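With the new defaults, a test can drive the loop like this (sketch; measure_performance_freq=1 forces a measurement every step so the metric logs exist):

from tests.utils.fake_training_loop import fake_classifier_training_loop

training_completed, trainer = fake_classifier_training_loop(
    logdir="testlog", measure_performance_freq=1
)
assert training_completed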
