Skip to content

Commit

Permalink
Merge pull request #51 from mr-ubik/43-precision-model-selection
Browse files Browse the repository at this point in the history
Fix metric value float precision issues during model selection
  • Loading branch information
galeone committed Feb 11, 2020
2 parents df17d72 + 12c9ef7 commit d9aff79
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/ashpy/metrics/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def inception_score(self, images: tf.Tensor) -> tf.Tensor:
:obj:`tuple` of (:py:class:`numpy.ndarray`, :py:class:`numpy.ndarray`): Mean and STD.
"""
tf.print("Computing inception score...")
print("Computing inception score...")

predictions: tf.Tensor = self._inception_model(images)

Expand Down
10 changes: 6 additions & 4 deletions src/ashpy/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ def model_selection(
self.json_read(self.best_model_sel_file)[self.sanitized_name]
)
# Model selection is done ONLY if an operator was passed at __init__
if self._model_selection_operator and self._model_selection_operator(
current_value, previous_value
if (
self._model_selection_operator
and self._model_selection_operator(current_value, previous_value)
and not np.isclose(current_value, previous_value)
):
tf.print(
f"{self.sanitized_name}: validation value: {previous_value}{current_value}"
print(
f"{self.sanitized_name}: validation value: {previous_value}{current_value}",
)
self.json_write(
self.best_model_sel_file,
Expand Down
2 changes: 1 addition & 1 deletion src/ashpy/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
# set and validate metrics
if metrics is None:
metrics = ()
self._metrics = metrics
self._metrics: Tuple[Optional[Metric]] = metrics
self._validate_metrics()

# set and validate callbacks
Expand Down
56 changes: 55 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from pathlib import Path

import pytest
from ashpy.metrics import ClassifierLoss
import tensorflow as tf
from ashpy.metrics import ClassifierLoss, Metric

from tests.utils.fake_training_loop import FakeAdversarialTraining, FakeTraining

Expand Down Expand Up @@ -143,3 +144,56 @@ def test_metrics_on_restart(fake_training_fn, tmpdir):
}
print(t2_values)
assert t1_values == t2_values


# -------------------------------------------------------------------------------------


def test_metric_precision(fake_training_fn, tmpdir, capsys):
    """
    Check that float precision noise does not re-trigger model selection.

    GIVEN a metric that is fed the same value at every update
    THEN the model selection log line ("fake_metric: validation value:") must
    be present exactly once in the captured stdout.

    If the `np.isclose` clause is removed from `Metric.model_selection()`, the
    extremely small variations of the value caused by floating point precision
    can trigger the model selection multiple times unnecessarily.
    """

    class FakeMetric(Metric):
        """Fake Metric that keeps reporting one value, sampled once at construction."""

        def __init__(self, name="fake_metric", model_selection_operator=operator.gt):
            mean_tracker = tf.metrics.Mean(name=name, dtype=tf.float32)
            super().__init__(
                name=name,
                model_selection_operator=model_selection_operator,
                metric=mean_tracker,
            )
            # Ratio of random positive tensors scaled down by 10000; only the
            # first element is kept and reused for every update.
            numerator = tf.exp(tf.random.normal((100,)))
            denominator = tf.add(
                tf.exp(tf.random.normal((100,))),
                tf.exp(tf.random.normal((100,), 10)),
            )
            scaled_ratio = tf.divide(numerator, denominator) / 10000
            self.fake_score = scaled_ratio.numpy()[0]
            print("FAKE SCORE: ", self.fake_score)

        def update_state(self, context):
            """Feed the stored score to the wrapped metric; `context` is not used."""
            score = self.fake_score

            def _push_score():
                return self._metric.update_state(score)

            self._distribute_strategy.experimental_run_v2(_push_score)

    training_loop: FakeTraining = fake_training_fn(tmpdir)
    # Append the fake metric and rebuild the trainer so it picks up the new
    # metrics tuple and epoch count.
    training_loop.metrics = (*training_loop.metrics, FakeMetric())
    training_loop.epochs = 5
    training_loop.build_trainer()
    assert training_loop()

    stdout, _ = capsys.readouterr()
    assert stdout.count("fake_metric: validation value:") == 1
56 changes: 37 additions & 19 deletions tests/utils/fake_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,14 @@ def __init__(
measure_performance_freq=10,
):
"""Fake Classifier training loop implementation using an autoencoder as a base model."""
self.logdir = logdir
self.epochs = epochs
self.measure_performance_freq = measure_performance_freq

self.optimizer = optimizer

self.metrics = metrics

# Model
self.model: tf.keras.Model = conv_autoencoder(
layer_spec_input_res,
Expand All @@ -82,20 +88,23 @@ def __init__(
)

# Loss
reconstruction_error = ashpy.losses.ClassifierLoss(
self.reconstruction_error = ashpy.losses.ClassifierLoss(
tf.keras.losses.MeanSquaredError()
)

# Trainer
self.trainer: ClassifierTrainer
self.build_trainer()

def build_trainer(self):
self.trainer = ClassifierTrainer(
model=self.model,
optimizer=optimizer,
loss=reconstruction_error,
logdir=str(logdir),
epochs=epochs,
metrics=metrics,
optimizer=self.optimizer,
loss=self.reconstruction_error,
logdir=str(self.logdir),
epochs=self.epochs,
metrics=self.metrics,
)
self.metrics = metrics

def __call__(self) -> bool:
self.trainer(
Expand Down Expand Up @@ -134,6 +143,11 @@ def __init__(
discriminator=None,
):
"""Fake training loop implementation."""
self.generator_loss = generator_loss
self.discriminator_loss = discriminator_loss
self.epochs = epochs
self.logdir = logdir

self.measure_performance_freq = measure_performance_freq

# test parameters
Expand Down Expand Up @@ -163,18 +177,8 @@ def __init__(
self.discriminator = discriminator

# Trainer
self.trainer = AdversarialTrainer(
generator=generator,
discriminator=discriminator,
generator_optimizer=tf.optimizers.Adam(1e-4),
discriminator_optimizer=tf.optimizers.Adam(1e-4),
generator_loss=generator_loss,
discriminator_loss=discriminator_loss,
epochs=epochs,
metrics=metrics,
callbacks=callbacks,
logdir=logdir,
)
self.trainer: AdversarialTrainer
self.build_trainer()

self.dataset = fake_adversarial_dataset(
image_resolution=image_resolution,
Expand All @@ -190,3 +194,17 @@ def __call__(self) -> bool:
self.dataset, measure_performance_freq=self.measure_performance_freq
)
return True

def build_trainer(self):
    """Create a new AdversarialTrainer from this instance's attributes and store it in `self.trainer`.

    Both optimizers are fresh `tf.optimizers.Adam(1e-4)` instances on every call.
    """
    trainer_kwargs = dict(
        generator=self.generator,
        discriminator=self.discriminator,
        generator_optimizer=tf.optimizers.Adam(1e-4),
        discriminator_optimizer=tf.optimizers.Adam(1e-4),
        generator_loss=self.generator_loss,
        discriminator_loss=self.discriminator_loss,
        epochs=self.epochs,
        metrics=self.metrics,
        callbacks=self.callbacks,
        logdir=self.logdir,
    )
    self.trainer = AdversarialTrainer(**trainer_kwargs)

0 comments on commit d9aff79

Please sign in to comment.