
Commit e8889ed

chaene authored and copybara-github committed

Adds a dictionary with the name to layer mapping for the custom layers.

PiperOrigin-RevId: 386117006

1 parent 0222ca5 commit e8889ed

File tree

6 files changed: +1344 −0 lines changed
Lines changed: 361 additions & 0 deletions

@@ -0,0 +1,361 @@
# Copyright 2020 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network architectures from the progressive GAN paper.

Implemented according to the paper "Progressive Growing of GANs for Improved
Quality, Stability, and Variation"
https://arxiv.org/abs/1710.10196

Intermediate outputs and inputs are supported for the implementation of
"MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis"
https://arxiv.org/abs/1903.06048

The implementations use Keras models with the Functional API. Only a subset of
the architectures presented in the papers is implemented; in particular,
progressive growing is not supported.
"""

import math
from typing import Callable, Optional, Sequence, Union

import tensorflow as tf
import tensorflow_addons.layers.normalizations as tfa_normalizations

from tensorflow_graphics.projects.gan import keras_layers

# Type aliases for initializers that can be passed to Keras layers.
_InitializerCallable = Callable[[tf.Tensor, tf.dtypes.DType], tf.Tensor]
_KerasInitializer = Union[_InitializerCallable, str]


def to_rgb(input_tensor: tf.Tensor,
           kernel_initializer: _KerasInitializer,
           name: Optional[str] = None) -> tf.Tensor:
  """Converts a feature map to an RGB output.

  Args:
    input_tensor: The input feature map.
    kernel_initializer: The kernel initializer to use.
    name: The name of the layer.

  Returns:
    The RGB image.
  """
  # A 1x1 convolution maps the feature channels to the 3 RGB channels.
  return keras_layers.FanInScaledConv2D(
      multiplier=1.0,
      filters=3,
      kernel_size=1,
      strides=1,
      kernel_initializer=kernel_initializer,
      padding='same',
      name=name)(
          input_tensor)


def create_generator(latent_code_dimension: int = 128,
                     upsampling_blocks_num_channels: Sequence[int] = (512, 256,
                                                                      128, 64),
                     relu_leakiness: float = 0.2,
                     kernel_initializer: Optional[_KerasInitializer] = None,
                     use_pixel_normalization: bool = True,
                     use_batch_normalization: bool = False,
                     generate_intermediate_outputs: bool = False,
                     normalize_latent_code: bool = True,
                     name: str = 'progressive_gan_generator') -> tf.keras.Model:
  """Creates a Keras model for the generator network architecture.

  This architecture is implemented according to the paper "Progressive Growing
  of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196
  The intermediate outputs are optionally provided for the architecture of
  "MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis"
  https://arxiv.org/abs/1903.06048

  Args:
    latent_code_dimension: The number of dimensions in the latent code.
    upsampling_blocks_num_channels: The number of channels for each upsampling
      block. This argument also determines how many upsampling blocks are
      added.
    relu_leakiness: Slope of the negative part of the leaky ReLU.
    kernel_initializer: Initializer of the kernel. If None, a TruncatedNormal
      initializer is used.
    use_pixel_normalization: If pixel normalization layers should be inserted
      into the network.
    use_batch_normalization: If batch normalization layers should be inserted
      into the network.
    generate_intermediate_outputs: If True, the model outputs a list of
      tf.Tensors with increasing resolution, starting at the initial 4x4
      resolution up to the final output resolution.
    normalize_latent_code: If True, the latent code is normalized to unit
      length before feeding it to the network.
    name: The name of the Keras model.

  Returns:
    The created generator Keras model object.
  """
  if kernel_initializer is None:
    kernel_initializer = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=1.0)

  input_tensor = tf.keras.Input(shape=(latent_code_dimension,))
  if normalize_latent_code:
    maybe_normalized_input_tensor = keras_layers.PixelNormalization(axis=1)(
        input_tensor)
  else:
    maybe_normalized_input_tensor = input_tensor

  # Project the latent code to a 4x4 feature map with a scaled dense layer.
  tensor = keras_layers.FanInScaledDense(
      multiplier=math.sqrt(2.0) / 4.0,
      units=4 * 4 * latent_code_dimension,
      kernel_initializer=kernel_initializer)(
          maybe_normalized_input_tensor)
  tensor = tf.keras.layers.Reshape(target_shape=(4, 4, latent_code_dimension))(
      tensor)
  tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
  if use_batch_normalization:
    tensor = tf.keras.layers.BatchNormalization()(tensor)
  if use_pixel_normalization:
    tensor = keras_layers.PixelNormalization(axis=3)(tensor)
  tensor = keras_layers.FanInScaledConv2D(
      filters=upsampling_blocks_num_channels[0],
      kernel_size=3,
      strides=1,
      padding='same',
      kernel_initializer=kernel_initializer)(
          tensor)
  tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
  if use_batch_normalization:
    tensor = tf.keras.layers.BatchNormalization()(tensor)
  if use_pixel_normalization:
    tensor = keras_layers.PixelNormalization(axis=3)(tensor)

  # Each block doubles the resolution and applies two 3x3 convolutions.
  outputs = []
  for index, channels in enumerate(upsampling_blocks_num_channels):
    if generate_intermediate_outputs:
      outputs.append(
          to_rgb(
              input_tensor=tensor,
              kernel_initializer=kernel_initializer,
              name='side_output_%d_conv' % index))
    tensor = keras_layers.TwoByTwoNearestNeighborUpSampling()(tensor)

    for _ in range(2):
      tensor = keras_layers.FanInScaledConv2D(
          filters=channels,
          kernel_size=3,
          strides=1,
          padding='same',
          kernel_initializer=kernel_initializer)(
              tensor)
      tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
      if use_batch_normalization:
        tensor = tf.keras.layers.BatchNormalization()(tensor)
      if use_pixel_normalization:
        tensor = keras_layers.PixelNormalization(axis=3)(tensor)

  tensor = to_rgb(
      input_tensor=tensor,
      kernel_initializer=kernel_initializer,
      name='final_output')
  if generate_intermediate_outputs:
    outputs.append(tensor)

    return tf.keras.Model(inputs=input_tensor, outputs=outputs, name=name)
  else:
    return tf.keras.Model(inputs=input_tensor, outputs=tensor, name=name)
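
# NOTE: Illustrative usage sketch, not part of the original file. With the
# defaults above, the generator starts from a 4x4 feature map and applies four
# 2x upsampling blocks, so it maps a 128-d latent code to a 64x64 RGB image:
#
#   generator = create_generator()
#   images = generator(tf.random.normal([8, 128]))  # shape: (8, 64, 64, 3)
#
# With generate_intermediate_outputs=True it instead returns a list of five
# images at resolutions 4, 8, 16, 32, and 64.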


def create_conv_layer(use_fan_in_scaled_kernel: bool = False,
                      multiplier: float = math.sqrt(2),
                      **kwargs) -> tf.keras.layers.Conv2D:
  """Creates a convolutional layer.

  Args:
    use_fan_in_scaled_kernel: Whether to use a FanInScaledConv2D or a standard
      Conv2D layer.
    multiplier: Additional multiplier used only for the FanInScaledConv2D
      layer.
    **kwargs: Keyword arguments forwarded to the convolutional layers.

  Returns:
    The created convolutional layer instance.
  """
  if use_fan_in_scaled_kernel:
    return keras_layers.FanInScaledConv2D(multiplier=multiplier, **kwargs)
  else:
    return tf.keras.layers.Conv2D(**kwargs)
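
# NOTE: Illustrative usage sketch, not part of the original file. An
# equalized-learning-rate 3x3 convolution (as used by the blocks below) versus
# a plain Keras convolution:
#
#   conv = create_conv_layer(
#       use_fan_in_scaled_kernel=True, filters=64, kernel_size=3,
#       padding='same')
#   plain_conv = create_conv_layer(filters=64, kernel_size=3, padding='same')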


def from_rgb(input_tensor: tf.Tensor,
             use_fan_in_scaled_kernel: bool,
             num_channels: int,
             kernel_initializer: _KerasInitializer,
             relu_leakiness: float,
             name: str = 'from_rgb') -> tf.Tensor:
  """Converts an RGB input to a feature map.

  Args:
    input_tensor: The input RGB image.
    use_fan_in_scaled_kernel: If a fan-in scaled kernel should be used.
    num_channels: The number of output channels.
    kernel_initializer: The kernel initializer to use.
    relu_leakiness: The leakiness of the ReLU.
    name: The name of the block.

  Returns:
    The feature map.
  """
  with tf.name_scope(name):
    # A 1x1 convolution lifts the 3 RGB channels to num_channels features.
    output = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernel,
        filters=num_channels,
        kernel_size=1,
        strides=1,
        kernel_initializer=kernel_initializer,
        padding='same')(
            input_tensor)
    return tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(output)


def create_discriminator(
    downsampling_blocks_num_channels: Sequence[Sequence[int]] = ((64, 128),
                                                                 (128, 128),
                                                                 (256, 256),
                                                                 (512, 512)),
    relu_leakiness: float = 0.2,
    kernel_initializer: Optional[_KerasInitializer] = None,
    use_fan_in_scaled_kernels: bool = True,
    use_layer_normalization: bool = False,
    use_intermediate_inputs: bool = False,
    use_antialiased_bilinear_downsampling: bool = False,
    name: str = 'progressive_gan_discriminator'):
  """Creates a Keras model for the discriminator architecture.

  This architecture is implemented according to the paper "Progressive Growing
  of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196
  The intermediate outputs can optionally be given as input for the
  architecture of "MSG-GAN: Multi-Scale Gradient GAN for Stable Image
  Synthesis"
  https://arxiv.org/abs/1903.06048

  Args:
    downsampling_blocks_num_channels: The number of channels in the
      downsampling blocks. For each block, the number of channels of the first
      and second convolution is specified.
    relu_leakiness: Slope of the negative part of the leaky ReLU.
    kernel_initializer: Initializer of the kernel. If None, a TruncatedNormal
      initializer is used.
    use_fan_in_scaled_kernels: Rescales the kernels using the scale factor from
      the He initializer, which implements the equalized learning rate.
    use_layer_normalization: If layer normalization layers should be inserted
      into the network.
    use_intermediate_inputs: If True, the model expects a list of tf.Tensors
      with increasing resolution, starting at the initial 4x4 resolution up to
      the final resolution, as input.
    use_antialiased_bilinear_downsampling: If True, the downsampling operation
      is anti-aliased bilinear downsampling with a [1, 3, 3, 1] tent kernel. If
      False, standard bilinear downsampling, i.e. average pooling, is used
      ([1, 1] tent kernel).
    name: The name of the Keras model.

  Returns:
    The created discriminator Keras model.
  """
  if kernel_initializer is None:
    kernel_initializer = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=1.0)

  if use_intermediate_inputs:
    inputs = tuple(
        tf.keras.Input(shape=(None, None, 3))
        for _ in range(len(downsampling_blocks_num_channels) + 1))
    tensor = inputs[-1]
  else:
    input_tensor = tf.keras.Input(shape=(None, None, 3))
    tensor = input_tensor

  tensor = from_rgb(
      tensor,
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      num_channels=downsampling_blocks_num_channels[0][0],
      kernel_initializer=kernel_initializer,
      relu_leakiness=relu_leakiness)
  if use_layer_normalization:
    tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

  # Each block applies two 3x3 convolutions and halves the resolution.
  for index, (channels_1,
              channels_2) in enumerate(downsampling_blocks_num_channels):
    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        filters=channels_1,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=kernel_initializer)(
            tensor)
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    if use_layer_normalization:
      tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)
    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        filters=channels_2,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=kernel_initializer)(
            tensor)
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    if use_layer_normalization:
      tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)
    if use_antialiased_bilinear_downsampling:
      tensor = keras_layers.Blur2D()(tensor)
    tensor = tf.keras.layers.AveragePooling2D()(tensor)

    if use_intermediate_inputs:
      tensor = tf.keras.layers.Concatenate()([inputs[-index - 2], tensor])

  tensor = create_conv_layer(
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      filters=downsampling_blocks_num_channels[-1][1],
      kernel_size=3,
      strides=1,
      padding='same',
      kernel_initializer=kernel_initializer)(
          tensor)
  tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
  if use_layer_normalization:
    tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

  # The 4x4 'valid' convolution reduces the spatial dimensions to 1x1.
  tensor = create_conv_layer(
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      filters=downsampling_blocks_num_channels[-1][1],
      kernel_size=4,
      strides=1,
      padding='valid',
      kernel_initializer=kernel_initializer)(
          tensor)
  tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
  if use_layer_normalization:
    tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

  tensor = create_conv_layer(
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      multiplier=1.0,
      filters=1,
      kernel_size=1,
      kernel_initializer=kernel_initializer)(
          tensor)
  tensor = tf.keras.layers.Reshape((-1,))(tensor)

  if use_intermediate_inputs:
    return tf.keras.Model(inputs=inputs, outputs=tensor, name=name)
  else:
    return tf.keras.Model(inputs=input_tensor, outputs=tensor, name=name)
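
# NOTE: Illustrative usage sketch, not part of the original file. With the
# defaults, the discriminator reduces a 64x64 image to a single logit per
# sample:
#
#   discriminator = create_discriminator()
#   logits = discriminator(tf.random.normal([2, 64, 64, 3]))  # shape: (2, 1)
#
# For the MSG-GAN setup, the generator's intermediate outputs (4x4 up to
# 64x64) match the multi-scale inputs this model expects when
# use_intermediate_inputs=True:
#
#   generator = create_generator(generate_intermediate_outputs=True)
#   discriminator = create_discriminator(use_intermediate_inputs=True)
#   multi_scale_images = generator(tf.random.normal([2, 128]))
#   logits = discriminator(multi_scale_images)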
