Skip to content

Commit

Permalink
Merge pull request #35 from mr-ubik/issue-28-model-selection-with-dec…
Browse files Browse the repository at this point in the history
…reasing-metrics

Fix metrics names collision
  • Loading branch information
galeone committed Jan 22, 2020
2 parents b5d7215 + 259ae5d commit 5429172
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 56 deletions.
4 changes: 2 additions & 2 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Fixes # (issue)

Please delete options that are not relevant.

- [ ] Bug fix (i.e., a non-breaking change which fixes an issue)
- [ ] Bugfix (i.e., a non-breaking change which fixes an issue)
- [ ] New feature (i.e., a non-breaking change which adds functionality)
- [ ] Breaking change (i.e., a fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
Expand All @@ -46,7 +46,7 @@ List any dependencies that are required for this change.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have commented on my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
Expand Down
6 changes: 4 additions & 2 deletions src/ashpy/metrics/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ class ClassifierLoss(Metric):

def __init__(
self,
name: str = "loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
) -> None:
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -54,8 +56,8 @@ def __init__(
"""
super().__init__(
name="loss",
metric=tf.metrics.Mean(name="loss", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down
28 changes: 19 additions & 9 deletions src/ashpy/metrics/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ class DiscriminatorLoss(Metric):

def __init__(
self,
name: str = "d_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
) -> None:
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -57,8 +59,8 @@ def __init__(
"""
super().__init__(
name="d_loss",
metric=tf.metrics.Mean(name="d_loss", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down Expand Up @@ -100,13 +102,15 @@ class GeneratorLoss(Metric):

def __init__(
self,
name: str = "g_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
):
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -120,8 +124,8 @@ def __init__(
"""
super().__init__(
name="g_loss",
metric=tf.metrics.Mean(name="g_loss", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down Expand Up @@ -162,13 +166,15 @@ class EncoderLoss(Metric):

def __init__(
self,
name: str = "e_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
) -> None:
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -182,8 +188,8 @@ def __init__(
"""
super().__init__(
name="e_loss",
metric=tf.metrics.Mean(name="e_loss", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down Expand Up @@ -233,6 +239,7 @@ class InceptionScore(Metric):
def __init__(
self,
inception: tf.keras.Model,
name: str = "inception_score",
model_selection_operator=operator.gt,
logdir=os.path.join(os.getcwd(), "log"),
):
Expand All @@ -241,6 +248,7 @@ def __init__(
Args:
inception (:py:class:`tf.keras.Model`): Keras Inception model.
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -254,8 +262,8 @@ def __init__(
"""
super().__init__(
name="inception_score",
metric=tf.metrics.Mean("inception_score"),
name=name,
metric=tf.metrics.Mean(name),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down Expand Up @@ -418,6 +426,7 @@ class EncodingAccuracy(ClassifierMetric):
def __init__(
self,
classifier: tf.keras.Model,
name: str = "encoding_accuracy",
model_selection_operator: Callable = None,
logdir=os.path.join(os.getcwd(), "log"),
) -> None:
Expand All @@ -430,6 +439,7 @@ def __init__(
Args:
classifier (:py:class:`tf.keras.Model`): Keras Model to use as a Classifier to
measure the accuracy. Generally assumed to be the Inception Model.
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -443,7 +453,7 @@ def __init__(
"""
super().__init__(
metric=tf.metrics.Accuracy("encoding_accuracy"),
metric=tf.metrics.Accuracy(name),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down
22 changes: 15 additions & 7 deletions src/ashpy/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
Initialize the Metric object.
Args:
name (str): The name of the metric.
name (str): Name of the metric.
metric (:py:class:`tf.keras.metrics.Metric`): The Keras metric to use.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
Expand Down Expand Up @@ -81,32 +81,40 @@ def model_selection(
"""
current_value = self.result()
previous_value = float(self.json_read(self.best_model_sel_file)[self._name])
previous_value = float(
self.json_read(self.best_model_sel_file)[self._name.replace("/", "_")]
)
# Model selection is done ONLY if an operator was passed at __init__
if self._model_selection_operator and self._model_selection_operator(
current_value, previous_value
):
tf.print(
f"{self.name}: validation value: {previous_value}{current_value}"
f"{self._name.replace('/', '_')}: validation value: {previous_value}{current_value}"
)
Metric.json_write(
self.best_model_sel_file,
{self._name: str(current_value), "step": int(global_step.numpy())},
{
self._name.replace("/", "_"): str(current_value),
"step": int(global_step.numpy()),
},
)
manager = tf.train.CheckpointManager(
checkpoint, os.path.join(self.best_folder, "ckpts"), max_to_keep=1
)
manager.save()

def _update_logdir(self):
if not self._model_selection_operator:
pass
# write the initial value of the best metric
if not os.path.exists(self.best_model_sel_file):
os.makedirs(os.path.dirname(self.best_model_sel_file))
initial_value = (
np.inf if self._model_selection_operator is operator.lt else -np.inf
)
self.json_write(
self.best_model_sel_file, {self._name: str(initial_value), "step": 0}
self.best_model_sel_file,
{self._name.replace("/", "_"): str(initial_value), "step": 0},
)

@property
Expand Down Expand Up @@ -138,12 +146,12 @@ def logdir(self, logdir) -> None:
@property
def best_folder(self) -> str:
"""Retrieve the folder used to save the best model when doing model selection."""
return os.path.join(self.logdir, "best", self._name)
return os.path.join(self.logdir, "best", self._name.replace("/", "_"))

@property
def best_model_sel_file(self) -> str:
"""Retrieve the path to JSON file containing the measured performance of the best model."""
return os.path.join(self.best_folder, self._name + ".json")
return os.path.join(self.best_folder, self._name.replace("/", "_") + ".json")

@staticmethod
def json_read(filename: str) -> Dict[str, Any]:
Expand Down
6 changes: 4 additions & 2 deletions src/ashpy/metrics/sliced_wasserstein_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class SlicedWassersteinDistance(Metric):

def __init__(
self,
name: str = "SWD",
model_selection_operator: Callable = operator.lt,
logdir: str = os.path.join(os.getcwd(), "log"),
resolution: int = 128,
Expand All @@ -106,6 +107,7 @@ def __init__(
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -126,8 +128,8 @@ def __init__(
"""
super().__init__(
name="SWD",
metric=tf.metrics.Mean(name="SWD", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down
14 changes: 8 additions & 6 deletions src/ashpy/metrics/ssim_multiscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,21 @@ class SSIM_Multiscale(Metric): # pylint: disable=invalid-name

def __init__(
self,
name: str = "SSIM_Multiscale",
model_selection_operator: Callable = operator.lt,
logdir: str = os.path.join(os.getcwd(), "log"),
max_val: float = 2.0,
power_factors=None,
filter_size: int = 11,
filter_sigma: int = 1.5,
k1: int = 0.01,
k2: int = 0.03,
k1: float = 0.01,
k2: float = 0.03,
) -> None:
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Expand All @@ -74,14 +76,14 @@ def __init__(
0.1333), which are the values obtained in the original paper.
filter_size (int): Default value 11 (size of gaussian filter).
filter_sigma (int): Default value 1.5 (width of gaussian filter).
k1 (int): Default value 0.01.
k2 (int): Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
k1 (float): Default value 0.01.
k2 (float): Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
it would be better if we take the values in range of 0< K2 <0.4).
"""
super().__init__(
name="SSIM_Multiscale",
metric=tf.metrics.Mean(name="SSIM_Multiscale", dtype=tf.float32),
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
Expand Down
3 changes: 2 additions & 1 deletion src/ashpy/trainers/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,14 @@ def toy_dataset():
self._loss = loss
self._loss.reduction = tf.keras.losses.Reduction.NONE

self._avg_loss = ClassifierLoss()
self._avg_loss = ClassifierLoss(name="ashpy/avg_loss")
if metrics:
metrics.append(self._avg_loss)
else:
metrics = [self._avg_loss]

super()._update_metrics(metrics)
super()._validate_metrics()

self._checkpoint.objects.extend([self._optimizer, self._model])
self._restore_or_init()
Expand Down
14 changes: 5 additions & 9 deletions src/ashpy/trainers/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,16 @@ def __init__(
self._discriminator_loss.reduction = tf.losses.Reduction.NONE

losses_metrics = [
DiscriminatorLoss(logdir=logdir),
GeneratorLoss(logdir=logdir),
DiscriminatorLoss(name="ashpy/d_loss", logdir=logdir),
GeneratorLoss(name="ashpy/g_loss", logdir=logdir),
]
if metrics:
metrics.extend(losses_metrics)
else:
metrics = losses_metrics

super()._update_metrics(metrics)
super()._validate_metrics()

self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
Expand Down Expand Up @@ -425,12 +426,7 @@ def real_gen():
if os.path.exists(logdir):
shutil.rmtree(logdir)
metrics = [
metrics.gan.EncodingAccuracy(
classifier,
# model_selection_operator=operator.gt,
)
]
metrics = [metrics.gan.EncodingAccuracy(classifier)]
trainer = trainers.gan.EncoderTrainer(
generator=generator,
Expand Down Expand Up @@ -526,7 +522,7 @@ def __init__(
"""
if not metrics:
metrics = []
metrics.append(EncoderLoss(logdir=logdir))
metrics.append(EncoderLoss(name="ashpy/e_loss", logdir=logdir))
super().__init__(
generator=generator,
discriminator=discriminator,
Expand Down
5 changes: 5 additions & 0 deletions src/ashpy/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ def context(self, _context: Context):
def _validate_metrics(self):
"""Check if every metric is an :py:class:`ashpy.metrics.Metric`."""
validate_objects(self._metrics, Metric)
buffer = []
for metric in self._metrics:
if metric._name in buffer:
raise ValueError("Metric should have unique names.")
buffer.append(metric._name)

def _validate_callbacks(self):
"""Check if every callback is an :py:class:`ashpy.callbacks.Callback`."""
Expand Down
3 changes: 2 additions & 1 deletion tests/callbacks/test_custom_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class MCallback(Callback):
Check the number of times the on_event is triggered.
"""

def __init__(self, event):
def __init__(self, event) -> None:
"""Initialize Callback."""
super(MCallback, self).__init__()
self._event = event
self.counter = 0
Expand Down

0 comments on commit 5429172

Please sign in to comment.