Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Hiroki Sakuma authored and Hiroki Sakuma committed May 17, 2019
1 parent 0b12b90 commit f31e797
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 40 deletions.
25 changes: 15 additions & 10 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ def __init__(self, generator, discriminator, real_input_fn, fake_input_fn, spect
fake_magnitude_spectrograms, fake_instantaneous_frequencies = tf.unstack(fake_images, axis=1)
fake_waveforms = spectral_ops.convert_to_waveform(fake_magnitude_spectrograms, fake_instantaneous_frequencies, **spectral_params)

real_logits = discriminator(real_images, labels)
fake_logits = discriminator(fake_images, labels)
real_features, real_logits = discriminator(real_images, labels)
fake_features, fake_logits = discriminator(fake_images, labels)

# label conditioning from
# [Which Training Methods for GANs do actually Converge?]
# (https://arxiv.org/pdf/1801.04406.pdf)
real_logits = tf.gather_nd(real_logits, indices=tf.where(labels))
fake_logits = tf.gather_nd(fake_logits, indices=tf.where(labels))

# non-saturating loss
discriminator_losses = tf.nn.softplus(-real_logits)
Expand Down Expand Up @@ -92,6 +98,10 @@ def __init__(self, generator, discriminator, real_input_fn, fake_input_fn, spect
self.fake_images = fake_images
self.real_labels = labels
self.fake_labels = labels
self.real_features = real_features
self.fake_features = fake_features
self.real_logits = real_logits
self.fake_logits = fake_logits
self.generator_loss = generator_loss
self.discriminator_loss = discriminator_loss
self.generator_train_op = generator_train_op
Expand Down Expand Up @@ -213,26 +223,21 @@ def generator():
while not session.should_stop():
try:
yield session.run([
real_features,
fake_features,
real_logits,
fake_logits,
self.real_features,
self.fake_features,
self.real_magnitude_spectrograms,
self.fake_magnitude_spectrograms
])
except tf.errors.OutOfRangeError:
break

real_features, fake_features, real_logits, fake_logits, \
real_magnitude_spectrograms, fake_magnitude_spectrograms = map(np.concatenate, zip(*generator()))
real_features, fake_features, real_magnitude_spectrograms, fake_magnitude_spectrograms = map(np.concatenate, zip(*generator()))

real_magnitude_spectrograms = np.reshape(real_magnitude_spectrograms, [real_magnitude_spectrograms.shape[0], -1])
fake_magnitude_spectrograms = np.reshape(real_magnitude_spectrograms, [fake_magnitude_spectrograms.shape[0], -1])

return dict(
frechet_inception_distance=metrics.frechet_inception_distance(real_features, fake_features),
real_inception_score=metrics.inception_score(real_logits),
fake_inception_score=metrics.inception_score(fake_logits),
num_different_bins=metrics.num_different_bins(real_magnitude_spectrograms, fake_magnitude_spectrograms)
)

Expand Down
37 changes: 7 additions & 30 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,22 +192,17 @@ def conv_block(inputs, depth, reuse=tf.AUTO_REUSE):
scale_weight=True
)
inputs = tf.nn.leaky_relu(inputs)
features = inputs
with tf.variable_scope("logits"):
# label conditioning from
# [Which Training Methods for GANs do actually Converge?]
# (https://arxiv.org/pdf/1801.04406.pdf)
inputs = dense(
inputs=inputs,
units=labels.shape[1],
use_bias=True,
variance_scale=1.0,
scale_weight=True
)
inputs = tf.gather_nd(
params=inputs,
indices=tf.where(labels)
)
return inputs
logits = inputs
return features, logits
else:
with tf.variable_scope("conv"):
inputs = conv2d(
Expand Down Expand Up @@ -314,17 +309,13 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
(https://arxiv.org/pdf/1603.05027.pdf)
by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
'''

shortcut = inputs

with tf.variable_scope("group_normalization_1st"):
inputs = group_normalization(
inputs=inputs,
groups=groups
)

inputs = tf.nn.relu(inputs)

if projection_shortcut:
with tf.variable_scope("projection_shortcut"):
shortcut = conv2d(
Expand All @@ -336,7 +327,6 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
variance_scale=2.0,
apply_weight_standardization=True
)

with tf.variable_scope("conv_1st"):
inputs = conv2d(
inputs=inputs,
Expand All @@ -347,15 +337,12 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
variance_scale=2.0,
apply_weight_standardization=True
)

with tf.variable_scope("group_normalization_2nd"):
inputs = group_normalization(
inputs=inputs,
groups=groups
)

inputs = tf.nn.relu(inputs)

with tf.variable_scope("conv_2nd"):
inputs = conv2d(
inputs=inputs,
Expand All @@ -366,13 +353,10 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
variance_scale=2.0,
apply_weight_standardization=True
)

inputs += shortcut

return inputs

with tf.variable_scope(name, reuse=reuse):

if self.conv_param:
with tf.variable_scope("conv"):
inputs = conv2d(
Expand All @@ -384,16 +368,13 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
variance_scale=2.0,
apply_weight_standardization=True
)

if self.pool_param:
inputs = max_pooling2d(
inputs=inputs,
kernel_size=self.pool_param.kernel_size,
strides=self.pool_param.strides
)

for i, residual_param in enumerate(self.residual_params):

for j in range(residual_param.blocks)[:1]:
with tf.variable_scope(f"residual_block_{i}_{j}"):
inputs = residual_block(
Expand All @@ -403,7 +384,6 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
projection_shortcut=True,
groups=self.groups
)

for j in range(residual_param.blocks)[1:]:
with tf.variable_scope(f"residual_block_{i}_{j}"):
inputs = residual_block(
Expand All @@ -413,24 +393,21 @@ def residual_block(inputs, filters, strides, projection_shortcut, groups):
projection_shortcut=False,
groups=self.groups
)

with tf.variable_scope("group_normalization"):
inputs = group_normalization(
inputs=inputs,
groups=self.groups
)

inputs = tf.nn.relu(inputs)

features = tf.reduce_mean(inputs, axis=[2, 3])

inputs = tf.reduce_mean(inputs, axis=[2, 3])
features = inputs
with tf.variable_scope("logits"):
logits = dense(
inputs = dense(
inputs=features,
units=self.classes,
use_bias=True,
variance_scale=1.0,
apply_weight_standardization=False
)

logits = inputs
return features, logits

0 comments on commit f31e797

Please sign in to comment.