Skip to content

Commit a83c151

Browse files
committed
Lots of fixes
1 parent b81bbc9 commit a83c151

File tree

7 files changed

+89
-104
lines changed

7 files changed

+89
-104
lines changed

gated_pixelcnn/model.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def build(self, input_shape):
202202
stack='H',
203203
type='A',
204204
n_colors=self.n_colors,
205-
kernel_size=(1, 3),
205+
kernel_size=3,
206206
padding='SAME',
207207
filters=self.hidden_dim * self.n_colors
208208
)
@@ -212,6 +212,18 @@ def build(self, input_shape):
212212
for i in range(self.n_res)
213213
]
214214

215+
self.final_conv_h = tfkl.Conv2D(
216+
filters = self.n_output * self.n_colors,
217+
kernel_size = 1,
218+
name='final_conv_h'
219+
)
220+
221+
self.final_conv_v = tfkl.Conv2D(
222+
filters = self.n_output * self.n_colors,
223+
kernel_size = 1,
224+
name='final_conv_v'
225+
)
226+
215227
self.final_conv = tfkl.Conv2D(
216228
filters = self.n_output * self.n_colors,
217229
kernel_size = 1,
@@ -225,7 +237,10 @@ def call(self, x):
225237
for res_block in self.res_blocks:
226238
v_stack, h_stack = res_block(v_stack, h_stack)
227239

228-
h = self.final_conv(tf.nn.relu(v_stack + h_stack))
240+
h = self.final_conv_h(tf.nn.relu(h_stack)) + \
241+
self.final_conv_v(tf.nn.relu(v_stack))
242+
243+
h = self.final_conv(tf.nn.relu(h))
229244

230245
# Format output
231246
h = tf.split(h, num_or_size_splits=self.n_colors, axis=-1)
@@ -254,3 +269,13 @@ def sample(self, n):
254269
samples = tf.tensor_scatter_nd_update(samples, indices, updates)
255270

256271
return samples
272+
273+
def bits_per_dim_loss(y_true, y_pred):
274+
"""Return the bits per dim value of the predicted distribution."""
275+
B, H, W, C = y_true.shape
276+
num_pixels = float(H * W * C)
277+
log_probs = tf.math.log_softmax(y_pred, axis=-1)
278+
log_probs = tf.gather(log_probs, tf.cast(y_true, tf.int32), axis=-1, batch_dims=4)
279+
nll = - tf.reduce_sum(log_probs, axis=[1, 2, 3])
280+
bits_per_dim = nll / num_pixels / tf.math.log(2.)
281+
return bits_per_dim

gated_pixelcnn/train_mnist.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import tensorflow as tf
44
import tensorflow_datasets as tfds
55

6-
from model import GatedPixelCNN
6+
from model import GatedPixelCNN, bits_per_dim_loss
77
from utils import PlotSamplesCallback
88

99
tfk = tf.keras
1010
tfkl = tf.keras.layers
1111
AUTOTUNE = tf.data.experimental.AUTOTUNE
1212

1313
# Training parameters
14-
EPOCHS = 10
14+
EPOCHS = 75
1515
BATCH_SIZE = 64
1616
BUFFER_SIZE = 1024 # for shuffling
1717

@@ -43,9 +43,8 @@ def duplicate(element):
4343
# Define model
4444
strategy = tf.distribute.MirroredStrategy()
4545
with strategy.scope():
46-
model = GatedPixelCNN(hidden_dim=64, n_res=5)
47-
loss = tfk.losses.SparseCategoricalCrossentropy(from_logits=True)
48-
model.compile(optimizer='adam', loss=loss)
46+
model = GatedPixelCNN(hidden_dim=64, n_res=6)
47+
model.compile(optimizer='adam', loss=bits_per_dim_loss)
4948

5049
# Callbacks
5150
time = datetime.now().strftime('%Y%m%d-%H%M%S')

pixelcnn/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,13 @@ def sample(self, n):
219219
samples = tf.tensor_scatter_nd_update(samples, indices, updates)
220220

221221
return samples
222+
223+
def bits_per_dim_loss(y_true, y_pred):
224+
"""Return the bits per dim value of the predicted distribution."""
225+
B, H, W, C = y_true.shape
226+
num_pixels = float(H * W * C)
227+
log_probs = tf.math.log_softmax(y_pred, axis=-1)
228+
log_probs = tf.gather(log_probs, tf.cast(y_true, tf.int32), axis=-1, batch_dims=4)
229+
nll = - tf.reduce_sum(log_probs, axis=[1, 2, 3])
230+
bits_per_dim = nll / num_pixels / tf.math.log(2.)
231+
return bits_per_dim

pixelcnn/train_mnist.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import tensorflow as tf
44
import tensorflow_datasets as tfds
55

6-
from model import PixelCNN
6+
from model import PixelCNN, bits_per_dim_loss
77
from utils import PlotSamplesCallback
88

99
tfk = tf.keras
1010
tfkl = tf.keras.layers
1111
AUTOTUNE = tf.data.experimental.AUTOTUNE
1212

1313
# Training parameters
14-
EPOCHS = 10
14+
EPOCHS = 50
1515
BATCH_SIZE = 64
1616
BUFFER_SIZE = 1024 # for shuffling
1717

@@ -43,9 +43,8 @@ def duplicate(element):
4343
# Define model
4444
strategy = tf.distribute.MirroredStrategy()
4545
with strategy.scope():
46-
model = PixelCNN(hidden_dim=32, n_res=3)
47-
loss = tfk.losses.SparseCategoricalCrossentropy(from_logits=True)
48-
model.compile(optimizer='adam', loss=loss)
46+
model = PixelCNN(hidden_dim=64, n_res=6)
47+
model.compile(optimizer='adam', loss=bits_per_dim_loss)
4948

5049
# Callbacks
5150
time = datetime.now().strftime('%Y%m%d-%H%M%S')

pixelcnn_plus/model.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,12 @@ def build(self, input_shape):
309309
name='final_conv_h'
310310
)
311311

312+
self.final_conv = tfkl.Conv2D(
313+
filters = self.n_mix * self.n_component_per_mix,
314+
kernel_size = 1,
315+
name='final_conv'
316+
)
317+
312318
def call(self, x, training=False):
313319
# First convs
314320
v_stack = self.down_shift(self.first_conv_v(x))
@@ -323,8 +329,8 @@ def call(self, x, training=False):
323329
residuals_h.append(h_stack)
324330
residuals_v.append(v_stack)
325331
if ds < self.n_downsampling:
326-
v_stack = self.downsampling_convs_v[ds](v_stack)
327-
h_stack = self.downsampling_convs_h[ds](h_stack)
332+
v_stack = self.downsampling_convs_v[ds](tf.nn.relu(v_stack))
333+
h_stack = self.downsampling_convs_h[ds](tf.nn.relu(h_stack))
328334
residuals_h.append(h_stack)
329335
residuals_v.append(v_stack)
330336

@@ -348,13 +354,15 @@ def call(self, x, training=False):
348354
v_stack += residuals_v.pop()
349355
h_stack += residuals_h.pop()
350356
if us < self.n_downsampling:
351-
v_stack = self.upsampling_convs_v[us](v_stack)
352-
h_stack = self.upsampling_convs_h[us](h_stack)
357+
v_stack = self.upsampling_convs_v[us](tf.nn.relu(v_stack))
358+
h_stack = self.upsampling_convs_h[us](tf.nn.relu(h_stack))
353359
v_stack += residuals_v.pop()
354360
h_stack += residuals_h.pop()
355361

356362
# Final conv
357-
outputs = self.final_conv_h(h_stack) + self.final_conv_v(v_stack)
363+
outputs = self.final_conv_h(tf.nn.relu(h_stack)) + \
364+
self.final_conv_v(tf.nn.relu(v_stack))
365+
outputs = self.final_conv(tf.nn.relu(outputs))
358366

359367
return outputs
360368

@@ -382,8 +390,8 @@ def sample(self, n):
382390
beta = tf.math.tanh(beta)
383391
gamma = tf.math.tanh(gamma)
384392

385-
mu_g = mu_g + alpha * mu_r
386-
mu_b = mu_b + beta * mu_r + gamma * mu_g
393+
# mu_g = mu_g + alpha * mu_r
394+
# mu_b = mu_b + beta * mu_r + gamma * mu_g
387395
mu = tf.stack([mu_r, mu_g, mu_b], axis=2)
388396
logvar = tf.stack([logvar_r, logvar_g, logvar_b], axis=2)
389397

@@ -397,32 +405,45 @@ def sample(self, n):
397405
# Sample colors
398406
u = tf.random.uniform(tf.shape(mu), minval=1e-5, maxval=1. - 1e-5)
399407
x = mu + tf.exp(logvar) * (tf.math.log(u) - tf.math.log(1. - u))
400-
updates = tf.clip_by_value(x, -1., 1.)
408+
409+
# Readjust means
401410
if channels == 3:
402-
updates = updates[:, 0, :]
411+
alpha = tf.gather(alpha, components, axis=1, batch_dims=1)
412+
beta = tf.gather(beta, components, axis=1, batch_dims=1)
413+
gamma = tf.gather(gamma, components, axis=1, batch_dims=1)
414+
x_r = x[:, 0, 0]
415+
x_g = x[:, 0, 1] + alpha[:, 0] * x_r
416+
x_b = x[:, 0, 2] + beta[:, 0] * x_r + gamma[:, 0] * x_g
417+
x = tf.stack([x_r, x_g, x_b], axis=-1)
418+
419+
updates = tf.clip_by_value(x, -1., 1.)
403420
indices = tf.constant([[i, h, w] for i in range(n)])
404421
samples = tf.tensor_scatter_nd_update(samples, indices, updates)
405422

406423
return samples
407424

408425
def discretized_logistic_mix_loss(y_true, y_pred):
409-
# y_true shape (batch_size, H, W, channels)
410-
n_channels = y_true.shape[-1]
426+
# y_true shape (batch_size, H, W, C)
427+
_, H, W, C = y_true.shape
428+
num_pixels = float(H * W * C)
411429

412-
if n_channels == 1:
430+
if C == 1:
413431
pi, mu, logvar = tf.split(y_pred, num_or_size_splits=3, axis=-1)
414432
mu = tf.expand_dims(mu, axis=3)
415433
logvar = tf.expand_dims(logvar, axis=3)
416-
else: # n_channels == 3
434+
else: # C == 3
417435
(pi, mu_r, mu_g, mu_b, logvar_r, logvar_g, logvar_b, alpha,
418436
beta, gamma) = tf.split(y_pred, num_or_size_splits=10, axis=-1)
419437

420438
alpha = tf.math.tanh(alpha)
421439
beta = tf.math.tanh(beta)
422440
gamma = tf.math.tanh(gamma)
423441

424-
mu_g = mu_g + alpha * mu_r
425-
mu_b = mu_b + beta * mu_r + gamma * mu_g
442+
red = y_true[:,:,:,0:1]
443+
green = y_true[:,:,:,1:2]
444+
445+
mu_g = mu_g + alpha * red
446+
mu_b = mu_b + beta * red + gamma * green
426447
mu = tf.stack([mu_r, mu_g, mu_b], axis=3)
427448
logvar = tf.stack([logvar_r, logvar_g, logvar_b], axis=3)
428449

@@ -462,11 +483,14 @@ def log_pdf(x): # log logistic pdf
462483

463484
# Deal with edge cases
464485
log_probs = tf.where(y_true > 0.999, log_one_minus_cdf_min, log_probs)
465-
log_probs = tf.where(y_true < 0.999, log_cdf_plus, log_probs)
486+
log_probs = tf.where(y_true < -0.999, log_cdf_plus, log_probs)
466487

467488
log_probs = tf.reduce_sum(log_probs, axis=3) # whole pixel prob per component
468489
log_probs += tf.nn.log_softmax(pi) # multiply by mixture components
469490
log_probs = tf.math.reduce_logsumexp(log_probs, axis=-1) # add components probs
470491
log_probs = tf.reduce_sum(log_probs, axis=[1, 2])
471492

472-
return -log_probs
493+
# Convert to bits per dim
494+
bits_per_dim = -log_probs / num_pixels / tf.math.log(2.)
495+
496+
return bits_per_dim

pixelcnn_plus/train_mnist.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tensorflow_datasets as tfds
55

66
from model import PixelCNNplus, discretized_logistic_mix_loss
7-
from utils import PlotSamplesCallback, PlotReconstructionCallback
7+
from utils import PlotSamplesCallback
88

99
tfk = tf.keras
1010
tfkl = tf.keras.layers
@@ -50,8 +50,7 @@ def duplicate(element):
5050
time = datetime.now().strftime('%Y%m%d-%H%M%S')
5151
log_dir = os.path.join('.', 'logs', 'pixelcnn++', time)
5252
tensorboard_clbk = tfk.callbacks.TensorBoard(log_dir=log_dir)
53-
sample_clbk = PlotSamplesCallback(logdir=log_dir, period=5)
54-
reconstruction_clbk = PlotReconstructionCallback(logdir=log_dir, test_ds=test_ds)
53+
sample_clbk = PlotSamplesCallback(logdir=log_dir, period=1)
5554
callbacks = [tensorboard_clbk, sample_clbk, reconstruction_clbk]
5655

5756
# Fit

pixelcnn_plus/utils.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def plot_to_image(figure):
2323

2424
class PlotSamplesCallback(tfk.callbacks.Callback):
2525
"""Plot `nex` reconstructed image to tensorboard."""
26-
def __init__(self, logdir: str, nex: int=4, period: int=5):
26+
def __init__(self, logdir: str, nex: int=4, period: int=1):
2727
super(PlotSamplesCallback, self).__init__()
2828
logdir = os.path.join(logdir, 'samples')
2929
self.file_writer = tf.summary.create_file_writer(logdir=logdir)
@@ -58,74 +58,3 @@ def on_epoch_end(self, epoch, logs=None):
5858
step=epoch,
5959
max_outputs=self.nex
6060
)
61-
62-
63-
class PlotReconstructionCallback(tfk.callbacks.Callback):
64-
"""Plot `nex` reconstructed image to tensorboard."""
65-
def __init__(self, logdir: str, test_ds: tf.data.Dataset, nex: int=4):
66-
super(PlotReconstructionCallback, self).__init__()
67-
logdir = os.path.join(logdir, 'reconstructions')
68-
self.file_writer = tf.summary.create_file_writer(logdir=logdir)
69-
self.nex = nex
70-
self.test_ds = test_ds.map(lambda x, y: x).unbatch().batch(nex)
71-
self.test_it = iter(self.test_ds)
72-
73-
def get_next_images(self):
74-
try:
75-
next_images = next(self.test_it)
76-
except StopIteration:
77-
self.test_it = iter(self.test_ds)
78-
next_images = next(self.test_it)
79-
return next_images
80-
81-
def plot_img_reconstruction(self, image, reconstruction):
82-
fig, ax = plt.subplots(nrows=1, ncols=2)
83-
84-
if image.shape[-1] == 1:
85-
image = tf.squeeze(image, axis=-1)
86-
reconstruction = tf.squeeze(reconstruction, axis=-1)
87-
88-
ax[0].imshow(image, vmin=-1., vmax=1., cmap=plt.cm.Greys)
89-
ax[0].set_title('Image')
90-
ax[0].axis('off')
91-
92-
ax[1].imshow(reconstruction, vmin=-1., vmax=1., cmap=plt.cm.Greys)
93-
ax[1].set_title('Reconstruction')
94-
ax[1].axis('off')
95-
96-
return fig
97-
98-
def get_means(self, logits):
99-
pi, mu, _ = tf.split(logits, num_or_size_splits=3, axis=-1)
100-
nex, height, width, n_mix = pi.shape
101-
102-
pi = tf.reshape(pi, shape=(-1, n_mix))
103-
# components = tf.random.categorical(logits=pi, num_samples=1)
104-
components = tf.argmax(pi, axis=-1)[:, None]
105-
106-
mu = tf.reshape(pi, shape=(-1, n_mix))
107-
mu = tf.gather(mu, components, axis=1, batch_dims=1)
108-
mu = tf.reshape(mu, (nex, height, width, 1))
109-
mu = tf.clip_by_value(mu, -1., 1.)
110-
111-
return mu
112-
113-
114-
def on_epoch_end(self, epoch, logs=None):
115-
images = self.get_next_images()
116-
logits = self.model(images)
117-
reconstructions = self.get_means(logits)
118-
119-
imgs = []
120-
for i in range(self.nex):
121-
fig = self.plot_img_reconstruction(images[i], reconstructions[i])
122-
imgs.append(plot_to_image(fig))
123-
124-
imgs = tf.concat(imgs, axis=0)
125-
with self.file_writer.as_default():
126-
tf.summary.image(
127-
name='Reconstructions',
128-
data=imgs,
129-
step=epoch,
130-
max_outputs=self.nex
131-
)

0 commit comments

Comments
 (0)