Skip to content

Commit

Permalink
Make Restorers work without checkpoint_map.json (#38)
Browse files Browse the repository at this point in the history
Make Restorers work without checkpoint_map.json
  • Loading branch information
mr-ubik committed Jan 29, 2020
2 parents f437294 + c9b1da3 commit 5e8fe9f
Show file tree
Hide file tree
Showing 21 changed files with 442 additions and 377 deletions.
11 changes: 6 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import operator
import os
import shutil
from pathlib import Path

import pytest
import tensorflow # pylint: disable=import-error
Expand Down Expand Up @@ -49,19 +50,19 @@ def add_common_namespaces(doctest_namespace):
@pytest.fixture(scope="function")
def save_dir():
"""Add the save_dir parameter to tests."""
m_save_dir = "testlog/savedir"
m_save_dir = Path("testlog/savedir")

# Clean before
if os.path.exists(m_save_dir):
if m_save_dir.exists():
shutil.rmtree(m_save_dir)
assert not os.path.exists(m_save_dir)
assert not m_save_dir.exists()

yield m_save_dir

# teardown
if os.path.exists(m_save_dir):
if m_save_dir.exists():
shutil.rmtree(m_save_dir)
assert not os.path.exists(m_save_dir)
assert not m_save_dir.exists()


# ------------------------------------------------------------------------------------
Expand Down
10 changes: 5 additions & 5 deletions examples/gans/pix2pix_facades.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Input Pipeline taken from: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
"""
import os
from pathlib import Path

import tensorflow as tf

Expand All @@ -34,7 +34,7 @@
_URL = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz"

PATH_TO_ZIP = tf.keras.utils.get_file("facades.tar.gz", origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(PATH_TO_ZIP), "facades/")
PATH = Path(PATH_TO_ZIP).parent / "facades"

BUFFER_SIZE = 100
BATCH_SIZE = 1
Expand Down Expand Up @@ -172,10 +172,10 @@ def main(
)

metrics = []
logdir = f'{"log"}/{dataset_name}/run2'
logdir = Path("log") / dataset_name / "run2"

if not os.path.exists(logdir):
os.makedirs(logdir)
if not logdir.exists():
logdir.mkdir(parents=True)

trainer = AdversarialTrainer(
generator=generator,
Expand Down
4 changes: 2 additions & 2 deletions examples/gans/pix2pix_facades_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def main(
metrics = []
logdir = f'{"log"}/{dataset_name}/run_multi'

if not os.path.exists(logdir):
os.makedirs(logdir)
if not logdir.exists():
logdir.mkdir(parents=True)

trainer = AdversarialTrainer(
generator=generator,
Expand Down
6 changes: 3 additions & 3 deletions src/ashpy/callbacks/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class LogImageGANCallback(CounterCallback):
import shutil
import operator
import os
from pathlib import Path
generator = models.gans.ConvGenerator(
layer_spec_input_res=(7, 7),
Expand Down Expand Up @@ -69,7 +69,7 @@ class LogImageGANCallback(CounterCallback):
# Trainer
epochs = 2
logdir = "testlog/callbacks"
logdir = Path("testlog/callbacks")
callbacks = [callbacks.LogImageGANCallback()]
trainer = trainers.gan.AdversarialTrainer(
generator=generator,
Expand Down Expand Up @@ -100,7 +100,7 @@ class LogImageGANCallback(CounterCallback):
trainer(dataset)
shutil.rmtree(logdir)
assert not os.path.exists(logdir)
assert not logdir.exists()
trainer._global_step.assign_add(500)
Expand Down
33 changes: 19 additions & 14 deletions src/ashpy/callbacks/save_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.

"""Save weights callback."""
import os
import shutil
from collections import deque
from enum import Enum, Flag, auto
from pathlib import Path
from typing import List

import tensorflow as tf
Expand Down Expand Up @@ -47,31 +47,32 @@ def name(self) -> str:
return "saved-model-and-weights"

@staticmethod
def _initialize_dirs(save_dir, save_format, save_sub_format):
def _initialize_dirs(save_dir: Path, save_format, save_sub_format) -> Path:
"""Initialize the directory for this save_format and sub-format."""
save_dir = os.path.join(save_dir, save_format.name())
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_dir = save_dir / save_format.name()
if not save_dir.exists():
save_dir.mkdir(parents=True)

save_dir = (
save_dir
if save_sub_format == SaveSubFormat.TF
else os.path.join(save_dir, save_format.name())
else save_dir / save_format.name()
)

return save_dir

def save(
self,
model: tf.keras.models.Model,
save_dir: str,
save_dir: Path,
save_sub_format: SaveSubFormat = SaveSubFormat.TF,
) -> None:
"""
Save the model using the correct format and sub-format.
Args:
model (:py:class:`tf.keras.models.Model`): model to Save.
save_dir (str): path of the file in which to save the model.
save_dir (:class:`pathlib.Path`): path of the file in which to save the model.
save_sub_format (:py:class:`ashpy.callbacks.save_callback.SaveSubFormat`): sub-format
of the save operation.
Expand All @@ -81,14 +82,18 @@ def save(
save_dir = self._initialize_dirs(
save_dir, SaveFormat.WEIGHTS, save_sub_format
)
model.save_weights(save_dir, save_format=save_sub_format.value)
# NOTE: Keras (TF 2.1.0) checks for h5 file using endswith attribute.
# Explicit conversion to strings is required
model.save_weights(str(save_dir), save_format=save_sub_format.value)

if SaveFormat.MODEL & self:

save_dir = self._initialize_dirs(
save_dir, SaveFormat.MODEL, save_sub_format
)
model.save(save_dir, save_format=save_sub_format.value)
# NOTE: TensorFlow 2.1.0 wanth either binary or unicod string.
# Explicit conversion to strings is required
model.save(str(save_dir), save_format=save_sub_format.value)

if not (SaveFormat.MODEL & self) | (SaveFormat.WEIGHTS & self):
raise NotImplementedError(
Expand Down Expand Up @@ -138,7 +143,7 @@ class SaveCallback(CounterCallback):

def __init__(
self,
save_dir: str,
save_dir: Path,
models: List[tf.keras.models.Model],
event: Event = Event.ON_EPOCH_END,
event_freq: int = 1,
Expand Down Expand Up @@ -248,10 +253,10 @@ def _save_weights_fn(self, step: int):
)

# Create the correct directory name
save_dir_i = os.path.join(self._save_dir, f"model-{i}-step-{step}")
save_dir_i = self._save_dir / f"model-{i}-step-{step}"

if not os.path.exists(save_dir_i):
os.makedirs(save_dir_i)
if not save_dir_i.exists():
save_dir_i.mkdir(parents=True)

# Add to the history
self._save_path_histories[i].append(save_dir_i)
Expand Down
8 changes: 5 additions & 3 deletions src/ashpy/metrics/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Union

import tensorflow as tf # pylint: disable=import-error
Expand All @@ -28,6 +28,8 @@

TPRocessingPredictions = Dict[str, Union[Callable, Dict[str, Any]]]

__ALL__ = ["ClassifierLoss", "ClassifierMetric"]


class ClassifierLoss(Metric):
"""A handy way to measure the classification loss."""
Expand All @@ -36,7 +38,7 @@ def __init__(
self,
name: str = "loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
) -> None:
"""
Initialize the Metric.
Expand Down Expand Up @@ -89,7 +91,7 @@ def __init__(
self,
metric: tf.keras.metrics.Metric,
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
processing_predictions=None,
) -> None:
"""
Expand Down
25 changes: 17 additions & 8 deletions src/ashpy/metrics/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import operator
import os
import types
from typing import TYPE_CHECKING, Callable
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Union

import tensorflow as tf
import tensorflow_hub as hub
Expand All @@ -31,6 +32,14 @@
GANEncoderContext,
)

__ALL__ = [
"DiscriminatorLoss",
"EncoderLoss",
"EncodingAccuracy",
"GeneratorLoss",
"InceptionScore",
]


class DiscriminatorLoss(Metric):
"""The Discriminator loss value."""
Expand All @@ -39,7 +48,7 @@ def __init__(
self,
name: str = "d_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
) -> None:
"""
Initialize the Metric.
Expand Down Expand Up @@ -104,7 +113,7 @@ def __init__(
self,
name: str = "g_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
):
"""
Initialize the Metric.
Expand Down Expand Up @@ -168,7 +177,7 @@ def __init__(
self,
name: str = "e_loss",
model_selection_operator: Callable = None,
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
) -> None:
"""
Initialize the Metric.
Expand Down Expand Up @@ -241,7 +250,7 @@ def __init__(
inception: tf.keras.Model,
name: str = "inception_score",
model_selection_operator=operator.gt,
logdir=os.path.join(os.getcwd(), "log"),
logdir=Path().cwd() / "log",
):
"""
Initialize the Metric.
Expand Down Expand Up @@ -355,7 +364,7 @@ def get_or_train_inception(
from_logits=True
),
optimizer: tf.keras.optimizers.Adam = tf.keras.optimizers.Adam(1e-5),
logdir: str = os.path.join(os.getcwd(), "log"),
logdir: Union[Path, str] = Path().cwd() / "log",
) -> tf.keras.Model:
"""
Restore or train (and save) the Inception model.
Expand Down Expand Up @@ -396,7 +405,7 @@ def get_or_train_inception(
ckpt.objects.extend([model, step])
logdir = logdir
manager = tf.train.CheckpointManager(
ckpt, os.path.join(logdir, "inception", name), max_to_keep=1
ckpt, logdir / "inception", name, max_to_keep=1
)

if manager.latest_checkpoint:
Expand Down Expand Up @@ -428,7 +437,7 @@ def __init__(
classifier: tf.keras.Model,
name: str = "encoding_accuracy",
model_selection_operator: Callable = None,
logdir=os.path.join(os.getcwd(), "log"),
logdir=Path().cwd() / "log",
) -> None:
"""
Measure the Generator and Encoder performance together.
Expand Down

0 comments on commit 5e8fe9f

Please sign in to comment.