Merge pull request #13 from zurutech/code-quality
Code quality + pix2pix multi-gpu example
galeone committed Jul 19, 2019
2 parents 5528e16 + d0eb193 commit 188f31e
Showing 18 changed files with 351 additions and 79 deletions.
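
The PR title mentions a pix2pix multi-GPU example; that example file is not among the diffs shown below. For orientation only, a generic tf.distribute.MirroredStrategy setup for GAN-style training looks roughly like the sketch that follows (toy models and data, not the repository's actual example):

import tensorflow as tf

# Hypothetical sketch: replicate a generator/discriminator pair across GPUs.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Variables created inside the scope are mirrored on every device.
    generator = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(3, 3, padding="same", activation="tanh")]
    )
    discriminator = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(1, 4, strides=2, padding="same")]
    )
    g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    d_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Toy paired dataset (condition, target); the global batch is split across replicas.
images = tf.random.normal([8, 64, 64, 3])
dataset = tf.data.Dataset.from_tensor_slices((images, images)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
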
5 changes: 4 additions & 1 deletion ashpy/contexts/base_context.py
@@ -29,7 +29,10 @@


class BaseContext:
""":py:class:`ashpy.contexts.base_context.BaseContext` provide an interface for all contexts to inherit from."""
"""
:py:class:`ashpy.contexts.base_context.BaseContext` provides an interface
for all contexts to inherit from.
"""

def __init__(
self,
5 changes: 4 additions & 1 deletion ashpy/contexts/classifier.py
@@ -29,7 +29,10 @@


class ClassifierContext(BaseContext):
r""":py:class:`ashpy.contexts.classifier.ClassifierContext` provide the standard functions to test a classifier."""
r"""
:py:class:`ashpy.contexts.classifier.ClassifierContext` provides
the standard functions to test a classifier.
"""

def __init__(
self,
4 changes: 3 additions & 1 deletion ashpy/contexts/gan.py
@@ -107,7 +107,9 @@ def discriminator_loss(self) -> Optional[Executor]:


class GANEncoderContext(GANContext):
r""":py:class:`ashpy.contexts.gan.GANEncoderContext` measure the specified metrics on the GAN."""
r"""
:py:class:`ashpy.contexts.gan.GANEncoderContext` measures the specified metrics on the GAN.
"""

def __init__(
self,
8 changes: 5 additions & 3 deletions ashpy/metrics/gan.py
@@ -278,14 +278,16 @@ def inception_score(self, images, splits=10):
),
:,
]
kl = part * (
kl_divergence = part * (
tf.math.log(part)
- tf.math.log(
tf.expand_dims(tf.math.reduce_mean(part, axis=0), axis=[0])
)
)
kl = tf.math.reduce_mean(tf.math.reduce_sum(kl, axis=1))
scores.append(tf.math.exp(kl))
kl_divergence = tf.math.reduce_mean(
tf.math.reduce_sum(kl_divergence, axis=1)
)
scores.append(tf.math.exp(kl_divergence))
return tf.math.reduce_mean(scores).numpy(), tf.math.reduce_std(scores).numpy()

@staticmethod
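
For reference, the quantity computed above is the inception score IS = exp(E_x[KL(p(y|x) || p(y))]), averaged over `splits` chunks. A self-contained sketch of the same computation over pre-computed class probabilities (variable names below are illustrative, not the metric's internals):

import tensorflow as tf

def inception_score_from_probs(probs, splits=10):
    """IS = exp(mean KL(p(y|x) || p(y))), estimated over `splits` chunks.

    `probs` is an [N, num_classes] tensor of classifier softmax outputs
    (InceptionV3 in the usual formulation).
    """
    scores = []
    chunk = probs.shape[0] // splits
    for i in range(splits):
        part = probs[i * chunk : (i + 1) * chunk]
        # p(y) is the marginal distribution over the chunk.
        marginal = tf.math.reduce_mean(part, axis=0, keepdims=True)
        kl = part * (tf.math.log(part) - tf.math.log(marginal))
        kl = tf.math.reduce_mean(tf.math.reduce_sum(kl, axis=1))
        scores.append(tf.math.exp(kl))
    scores = tf.stack(scores)
    return tf.math.reduce_mean(scores).numpy(), tf.math.reduce_std(scores).numpy()

# Toy usage with random probabilities.
probs = tf.nn.softmax(tf.random.normal([100, 10]), axis=1)
print(inception_score_from_probs(probs))
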
1 change: 1 addition & 0 deletions ashpy/models/convolutional/__init__.py
@@ -70,6 +70,7 @@
unet.UNet
unet.SUNet
unet.FunctionalUNet
----
3 changes: 2 additions & 1 deletion ashpy/models/convolutional/decoders.py
@@ -188,7 +188,8 @@ def _add_building_block(self, filters):

def _add_final_block(self, channels):
"""
Take the results of :func:`_add_building_block` and prepare them for the for the final output.
Take the results of :func:`_add_building_block` and prepare them
for the final output.
Args:
channels (int): Channels of the output images (1 for Grayscale, 3 for RGB).
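
For context, the final block of a decoder typically just maps the last feature maps to `channels` output channels in [-1, 1]; a generic sketch (not the exact layers ashpy appends here):

from tensorflow import keras

# Hypothetical final block: project features to `channels` image channels.
def final_block(channels):
    return [
        keras.layers.Conv2DTranspose(channels, kernel_size=5, strides=2, padding="same"),
        keras.layers.Activation(keras.activations.tanh),  # outputs in [-1, 1]
    ]
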
34 changes: 20 additions & 14 deletions ashpy/models/convolutional/discriminators.py
@@ -84,7 +84,8 @@ def __init__(
use_dropout (bool): whether to use dropout
dropout_prob (float): probability of dropout
non_linearity (:class:`tf.keras.layers.Layer`): non linearity used in the model
normalization_layer (:class:`tf.keras.layers.Layer`): normalization layer used in the model
normalization_layer (:class:`tf.keras.layers.Layer`): normalization
layer used in the model
use_attention (bool): whether to use attention
"""
self.use_attention = use_attention
@@ -241,12 +242,14 @@ def __init__(
input_res (int): input resolution
min_res (int): minimum resolution reached by the discriminators
kernel_size (int): kernel size of discriminators
initial_filters (int): number of initial filters in the first layer of the discriminators
initial_filters (int): number of initial filters in the first layer
of the discriminators
filters_cap (int): maximum number of filters in the discriminators
use_dropout (bool): whether to use dropout
dropout_prob (float): probability of dropout
non_linearity (:class:`tf.keras.layers.Layer`): non linearity used in discriminators
normalization_layer (:class:`tf.keras.layers.Layer`): normalization used by the discriminators
normalization_layer (:class:`tf.keras.layers.Layer`): normalization used by the
discriminators
use_attention (bool): whether to use attention
n_discriminators (int): Number of discriminators
@@ -284,7 +287,7 @@ def build_discriminator(self, input_res) -> Discriminator:
Returns:
A Discriminator (PatchDiscriminator).
"""
d = PatchDiscriminator(
return PatchDiscriminator(
input_res=input_res,
min_res=self.min_res,
kernel_size=self.kernel_size,
@@ -296,7 +299,6 @@ def build_discriminator(self, input_res) -> Discriminator:
use_attention=self.use_attention,
normalization_layer=self.normalization_layer,
)
return d

def call(
self, inputs: Union[List, tf.Tensor], training=True, return_features=False
@@ -308,28 +310,32 @@ def call(
return_features (bool): whether to return features or not
Returns:
([:py:class:`tf.Tensor`]): A List of Tensors containing the value of D_i for each input.
([:py:class:`tf.Tensor`]): A List of features for each discriminator if `return_features`
([:py:class:`tf.Tensor`]): A List of Tensors containing the
value of D_i for each input
([:py:class:`tf.Tensor`]): A List of features for each discriminator if
`return_features`
"""
is_conditioned = isinstance(inputs, list)

if is_conditioned:
xs, condition = (
fake_or_real, condition = (
inputs
) # inputs is a tuple containing the generated images and the conditions
else:
xs = inputs
fake_or_real = inputs
condition = None
outs = []
features = []

x_i = xs
fake_or_real_i = fake_or_real
condition_i = condition
for i, d in enumerate(self.discriminators):
for i, discriminator in enumerate(self.discriminators):
# compute value of the i-th discriminator
out, feat = d(
[x_i, condition_i] if condition_i is not None else x_i,
out, feat = discriminator(
[fake_or_real_i, condition_i]
if condition_i is not None
else fake_or_real_i,
training=training,
return_features=True,
)
@@ -339,7 +345,7 @@ def call(
features.extend(feat)
# reduce input size
if i != len(self.discriminators) - 1:
x_i = self.subsampling(x_i)
fake_or_real_i = self.subsampling(fake_or_real_i)
condition_i = (
self.subsampling(condition_i) if condition_i is not None else None
)
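
The refactored call above evaluates every discriminator on a progressively subsampled copy of the (fake or real) input; a condensed, self-contained illustration of that pattern (PatchD below is a stand-in, not ashpy's PatchDiscriminator):

import tensorflow as tf
from tensorflow import keras

class PatchD(keras.Model):
    """Stand-in patch discriminator returning (output map, features)."""

    def __init__(self):
        super().__init__()
        self.conv = keras.layers.Conv2D(64, 4, strides=2, padding="same")
        self.out_layer = keras.layers.Conv2D(1, 4, padding="same")

    def call(self, inputs, training=False):
        features = self.conv(inputs)
        return self.out_layer(features), [features]

discriminators = [PatchD() for _ in range(3)]
subsampling = keras.layers.AveragePooling2D()

fake_or_real = tf.random.normal([1, 256, 256, 3])
outs, features = [], []
for i, discriminator in enumerate(discriminators):
    out, feat = discriminator(fake_or_real, training=True)
    outs.append(out)
    features.extend(feat)
    if i != len(discriminators) - 1:
        # Halve the resolution before the next discriminator.
        fake_or_real = subsampling(fake_or_real)

print([o.shape for o in outs])  # one patch map per scale
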
3 changes: 2 additions & 1 deletion ashpy/models/convolutional/encoders.py
@@ -151,7 +151,8 @@ def _add_building_block(self, filters):

def _add_final_block(self, output_shape):
"""
Take the results of :func:`_add_building_block` and prepare them for the for the final output.
Take the results of :func:`_add_building_block` and prepare them for
the final output.
Args:
output_shape (int): Amount of units of the last :py:obj:`tf.keras.layers.Dense`
1 change: 0 additions & 1 deletion ashpy/models/convolutional/interfaces.py
@@ -17,7 +17,6 @@
import inspect

import numpy as np
import tensorflow as tf
from tensorflow import keras # pylint: disable=no-name-in-module

__ALL__ = ["Conv2DInterface"]
60 changes: 41 additions & 19 deletions ashpy/models/convolutional/pix2pixhd.py
@@ -18,7 +18,8 @@
Global Generator + Local Enhancer
.. [1] High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs: https://arxiv.org/abs/1711.11585
.. [1] High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs:
https://arxiv.org/abs/1711.11585
"""
import typing
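
As the module docstring above says, pix2pixHD composes a Global Generator with a Local Enhancer: the global network runs on a downsampled copy of the input, and its features are added to the enhancer's own downsampled features before upsampling back to full resolution. A toy, self-contained sketch of that wiring (all layers below are stand-ins, not the classes defined in this file):

import tensorflow as tf
from tensorflow import keras

class TinyGlobalGenerator(keras.Model):
    """Stand-in global generator returning (image, features)."""

    def __init__(self):
        super().__init__()
        self.body = keras.layers.Conv2D(64, 3, padding="same", activation="relu")
        self.to_rgb = keras.layers.Conv2D(3, 3, padding="same", activation="tanh")

    def call(self, inputs, training=False):
        features = self.body(inputs)
        return self.to_rgb(features), features

class TinyLocalEnhancer(keras.Model):
    """Stand-in local enhancer fusing global features at half resolution."""

    def __init__(self):
        super().__init__()
        self.global_generator = TinyGlobalGenerator()
        self.downsample = keras.layers.AveragePooling2D()
        self.down_block = keras.layers.Conv2D(64, 3, strides=2, padding="same")
        self.up_block = keras.layers.Conv2DTranspose(
            3, 3, strides=2, padding="same", activation="tanh"
        )

    def call(self, inputs, training=False):
        _, global_features = self.global_generator(self.downsample(inputs))
        local_features = self.down_block(inputs)  # half resolution
        fused = local_features + global_features  # inject global context
        return self.up_block(fused)  # back to full resolution

out = TinyLocalEnhancer()(tf.random.normal([1, 128, 128, 3]))
print(out.shape)  # (1, 128, 128, 3)
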
@@ -161,22 +162,31 @@ def __init__(
padding="same",
)

# un-comment in order to export the model
# @tf.function(
# input_signature=[tf.TensorSpec(shape=[None, 512, 512, 1], dtype=tf.float32)]
# )
def call(self, inputs, training=False):
"""
LocalEnhancer call.
Args:
inputs (:py:class:`tf.Tensor`): Input Tensors
training (bool): Whether it is training phase or not
Returns:
(:py:class:`tf.Tensor`): Image of size (input_res, input_res, channels)
as specified in the init call
"""
downsampled = self.downsample(inputs)

# call the global generator
global_generator_output, global_generator_features = self.global_generator(
downsampled
)
_, global_generator_features = self.global_generator(downsampled)

# first downsample
x = inputs
for layer in self.downsample_block:
if isinstance(layer, keras.layers.BatchNormalization) or isinstance(
layer, keras.layers.Dropout
if isinstance(
layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
):
x = layer(x, training=training)
else:
@@ -190,8 +200,8 @@ def call(self, inputs, training=False):

# upsample
for layer in self.upsample_block:
if isinstance(layer, keras.layers.BatchNormalization) or isinstance(
layer, keras.layers.Dropout
if isinstance(
layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
):
x = layer(x, training=training)
else:
@@ -227,10 +237,12 @@ def __init__(
Args:
filters (int): initial filters (same as the output filters)
normalization_layer (:class:`tf.keras.layers.Layer`): layer of normalization used by the residual block
normalization_layer (:class:`tf.keras.layers.Layer`): layer of normalization
used by the residual block
non_linearity (:class:`tf.keras.layers.Layer`): non linearity used in the resnet block
kernel_size (int): kernel size used in the resnet block
num_blocks (int): number of blocks, each block is composed by conv, normalization and non linearity
num_blocks (int): number of blocks, each block is composed of conv,
normalization, and non linearity
"""
super(ResNetBlock, self).__init__()
self.model_layers = []
@@ -297,12 +309,14 @@ def __init__(
initial_filters (int): number of initial filters
filters_cap (int): maximum number of filters
channels (int): output channels
normalization_layer (:class:`tf.keras.layers.Layer`): normalization layer used by the global generator
can be Instance Norm, Layer Norm, Batch Norm
non_linearity (:class:`tf.keras.layers.Layer`): non linearity used in the global generator
normalization_layer (:class:`tf.keras.layers.Layer`): normalization layer used
by the global generator, can be Instance Norm, Layer Norm, Batch Norm
non_linearity (:class:`tf.keras.layers.Layer`): non linearity
used in the global generator
num_resnet_blocks (int): number of resnet blocks
kernel_size_resnet (int): kernel size used in resnets conv layers
kernel_size_front_back (int): kernel size used by the convolutional front end and backend
kernel_size_front_back (int): kernel size used by the convolutional
frontend and backend
num_internal_resnet_blocks (int): number of blocks used by internal resnet
"""
super().__init__()
@@ -339,7 +353,7 @@ def __init__(

# ResNet Block
self.resnet_blocks = []
for block in range(num_resnet_blocks):
for _ in range(num_resnet_blocks):
self.resnet_blocks.append(
ResNetBlock(
filters,
@@ -383,14 +397,22 @@ def __init__(
self.model_layers.append(self.last_layer)

def call(self, inputs, training=True):
"""
Call of the Pix2Pix HD model.
Args:
inputs: input tensor(s)
training: If True, training phase
Returns:
:py:class:`Tuple`: Generated images.
"""
out = inputs
prev = inputs
for layer in self.model_layers:
prev = out
if (
isinstance(layer, ResNetBlock)
or isinstance(layer, keras.layers.BatchNormalization)
or isinstance(layer, keras.layers.Dropout)
if isinstance(
layer,
(ResNetBlock, keras.layers.BatchNormalization, keras.layers.Dropout),
):
out = layer(prev, training=training)
else:
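
The commented-out decorator earlier in this file hints at model export: giving call a concrete input_signature lets the model be traced once and written out as a SavedModel. A minimal illustration of the idea (TinyModel is a stand-in, and the 512x512x1 signature is simply the one quoted in the comment):

import tensorflow as tf
from tensorflow import keras

class TinyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.conv = keras.layers.Conv2D(3, 3, padding="same", activation="tanh")

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None, 512, 512, 1], dtype=tf.float32)]
    )
    def call(self, inputs):
        return self.conv(inputs)

model = TinyModel()
model(tf.zeros([1, 512, 512, 1]))  # trace once so the variables are created
tf.saved_model.save(model, "/tmp/tiny_pix2pixhd_export")
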
23 changes: 12 additions & 11 deletions ashpy/models/convolutional/unet.py
@@ -23,7 +23,7 @@
from ashpy.layers import Attention, InstanceNormalization
from ashpy.models.convolutional.interfaces import Conv2DInterface

__ALL__ = ["UNet", "SUNet"]
__ALL__ = ["UNet", "SUNet", "FunctionalUNet"]


class UNet(Conv2DInterface):
@@ -53,7 +53,8 @@ class UNet(Conv2DInterface):
(1, 512, 512, 3)
True
.. [1] Image-to-Image Translation with Conditional Adversarial Nets https://arxiv.org/abs/1611.04076
.. [1] Image-to-Image Translation with Conditional Adversarial Nets
https://arxiv.org/abs/1611.04076
"""

@@ -71,7 +72,7 @@ def __init__(
encoder_non_linearity: typing.Type[keras.layers.Layer] = keras.layers.LeakyReLU,
decoder_non_linearity: typing.Type[keras.layers.Layer] = keras.layers.ReLU,
normalization_layer: typing.Type[keras.layers.Layer] = InstanceNormalization,
last_activation: keras.activations = keras.activations.tanh, # tanh or softmax (for semantic images)
last_activation: keras.activations = keras.activations.tanh,
use_attention: bool = False,
):
"""
Expand All @@ -88,7 +89,7 @@ def __init__(
dropout_prob: probability of dropout
encoder_non_linearity: non linearity of encoder
decoder_non_linearity: non linearity of decoder
last_activation: last activation function
last_activation: last activation function, tanh or softmax (for semantic images)
use_attention: whether to use attention
"""
super().__init__()
@@ -121,7 +122,7 @@ def __init__(
decoder_layer_spec.insert(0, filters)
block = self.get_encoder_block(
filters,
use_bn=(i != 0 and i != (len(encoder_layers_spec) - 1)),
use_bn=(i not in (0, len(encoder_layers_spec) - 1)),
use_attention=i == 2,
)
self.encoder_layers.append(block)
@@ -239,8 +240,8 @@ def call(self, inputs, training=False):

for block in self.encoder_layers:
for layer in block:
if isinstance(layer, keras.layers.BatchNormalization) or isinstance(
layer, keras.layers.Dropout
if isinstance(
layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
):
x = layer(x, training=training)
else:
Expand All @@ -252,8 +253,8 @@ def call(self, inputs, training=False):

for i, block in enumerate(self.decoder_layers):
for layer in block:
if isinstance(layer, keras.layers.BatchNormalization) or isinstance(
layer, keras.layers.Dropout
if isinstance(
layer, (keras.layers.BatchNormalization, keras.layers.Dropout)
):
x = layer(x, training=training)
else:
@@ -319,7 +320,7 @@ def FUNet(
use_attention=False,
):
"""
Functional UNET
Functional UNET Implementation
"""
# ########### Encoder creation
encoder_layers_spec = Conv2DInterface._get_layer_spec(
@@ -378,7 +379,7 @@ def get_block(
kernel_size,
filters,
conv_layer=keras.layers.Conv2D,
use_bn=(i != 0 and i != (len(encoder_layers_spec) - 1)),
use_bn=(i not in (0, len(encoder_layers_spec) - 1)),
use_dropout=use_dropout_encoder,
non_linearity=encoder_non_linearity,
use_attention=(i == 2 and use_attention),
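
The same isinstance(layer, (BatchNormalization, Dropout)) dispatch now appears throughout the UNet and pix2pixHD call methods: only layers whose behaviour differs between training and inference receive the training flag explicitly. A standalone illustration of the idiom:

import tensorflow as tf
from tensorflow import keras

layers = [
    keras.layers.Conv2D(8, 3, padding="same"),
    keras.layers.BatchNormalization(),
    keras.layers.LeakyReLU(),
    keras.layers.Dropout(0.3),
]

def forward(x, training=False):
    for layer in layers:
        if isinstance(layer, (keras.layers.BatchNormalization, keras.layers.Dropout)):
            # These layers behave differently at train and inference time.
            x = layer(x, training=training)
        else:
            x = layer(x)
    return x

print(forward(tf.random.normal([1, 32, 32, 3]), training=True).shape)
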
