@@ -309,6 +309,12 @@ def build(self, input_shape):
309309 name = 'final_conv_h'
310310 )
311311
312+ self .final_conv = tfkl .Conv2D (
313+ filters = self .n_mix * self .n_component_per_mix ,
314+ kernel_size = 1 ,
315+ name = 'final_conv'
316+ )
317+
312318 def call (self , x , training = False ):
313319 # First convs
314320 v_stack = self .down_shift (self .first_conv_v (x ))
@@ -323,8 +329,8 @@ def call(self, x, training=False):
323329 residuals_h .append (h_stack )
324330 residuals_v .append (v_stack )
325331 if ds < self .n_downsampling :
326- v_stack = self .downsampling_convs_v [ds ](v_stack )
327- h_stack = self .downsampling_convs_h [ds ](h_stack )
332+ v_stack = self .downsampling_convs_v [ds ](tf . nn . relu ( v_stack ) )
333+ h_stack = self .downsampling_convs_h [ds ](tf . nn . relu ( h_stack ) )
328334 residuals_h .append (h_stack )
329335 residuals_v .append (v_stack )
330336
@@ -348,13 +354,15 @@ def call(self, x, training=False):
348354 v_stack += residuals_v .pop ()
349355 h_stack += residuals_h .pop ()
350356 if us < self .n_downsampling :
351- v_stack = self .upsampling_convs_v [us ](v_stack )
352- h_stack = self .upsampling_convs_h [us ](h_stack )
357+ v_stack = self .upsampling_convs_v [us ](tf . nn . relu ( v_stack ) )
358+ h_stack = self .upsampling_convs_h [us ](tf . nn . relu ( h_stack ) )
353359 v_stack += residuals_v .pop ()
354360 h_stack += residuals_h .pop ()
355361
356362 # Final conv
357- outputs = self .final_conv_h (h_stack ) + self .final_conv_v (v_stack )
363+ outputs = self .final_conv_h (tf .nn .relu (h_stack )) + \
364+ self .final_conv_v (tf .nn .relu (v_stack ))
365+ outputs = self .final_conv (tf .nn .relu (outputs ))
358366
359367 return outputs
360368
@@ -382,8 +390,8 @@ def sample(self, n):
382390 beta = tf .math .tanh (beta )
383391 gamma = tf .math .tanh (gamma )
384392
385- mu_g = mu_g + alpha * mu_r
386- mu_b = mu_b + beta * mu_r + gamma * mu_g
393+ # mu_g = mu_g + alpha * mu_r
394+ # mu_b = mu_b + beta * mu_r + gamma * mu_g
387395 mu = tf .stack ([mu_r , mu_g , mu_b ], axis = 2 )
388396 logvar = tf .stack ([logvar_r , logvar_g , logvar_b ], axis = 2 )
389397
@@ -397,32 +405,45 @@ def sample(self, n):
397405 # Sample colors
398406 u = tf .random .uniform (tf .shape (mu ), minval = 1e-5 , maxval = 1. - 1e-5 )
399407 x = mu + tf .exp (logvar ) * (tf .math .log (u ) - tf .math .log (1. - u ))
400- updates = tf .clip_by_value (x , - 1. , 1. )
408+
409+ # Readjust means
401410 if channels == 3 :
402- updates = updates [:, 0 , :]
411+ alpha = tf .gather (alpha , components , axis = 1 , batch_dims = 1 )
412+ beta = tf .gather (beta , components , axis = 1 , batch_dims = 1 )
413+ gamma = tf .gather (gamma , components , axis = 1 , batch_dims = 1 )
414+ x_r = x [:, 0 , 0 ]
415+ x_g = x [:, 0 , 1 ] + alpha [:, 0 ] * x_r
416+ x_b = x [:, 0 , 2 ] + beta [:, 0 ] * x_r + gamma [:, 0 ] * x_g
417+ x = tf .stack ([x_r , x_g , x_b ], axis = - 1 )
418+
419+ updates = tf .clip_by_value (x , - 1. , 1. )
403420 indices = tf .constant ([[i , h , w ] for i in range (n )])
404421 samples = tf .tensor_scatter_nd_update (samples , indices , updates )
405422
406423 return samples
407424
408425def discretized_logistic_mix_loss (y_true , y_pred ):
409- # y_true shape (batch_size, H, W, channels)
410- n_channels = y_true .shape [- 1 ]
426+ # y_true shape (batch_size, H, W, C)
427+ _ , H , W , C = y_true .shape
428+ num_pixels = float (H * W * C )
411429
412- if n_channels == 1 :
430+ if C == 1 :
413431 pi , mu , logvar = tf .split (y_pred , num_or_size_splits = 3 , axis = - 1 )
414432 mu = tf .expand_dims (mu , axis = 3 )
415433 logvar = tf .expand_dims (logvar , axis = 3 )
416- else : # n_channels == 3
434+ else : # C == 3
417435 (pi , mu_r , mu_g , mu_b , logvar_r , logvar_g , logvar_b , alpha ,
418436 beta , gamma ) = tf .split (y_pred , num_or_size_splits = 10 , axis = - 1 )
419437
420438 alpha = tf .math .tanh (alpha )
421439 beta = tf .math .tanh (beta )
422440 gamma = tf .math .tanh (gamma )
423441
424- mu_g = mu_g + alpha * mu_r
425- mu_b = mu_b + beta * mu_r + gamma * mu_g
442+ red = y_true [:,:,:,0 :1 ]
443+ green = y_true [:,:,:,1 :2 ]
444+
445+ mu_g = mu_g + alpha * red
446+ mu_b = mu_b + beta * red + gamma * green
426447 mu = tf .stack ([mu_r , mu_g , mu_b ], axis = 3 )
427448 logvar = tf .stack ([logvar_r , logvar_g , logvar_b ], axis = 3 )
428449
@@ -462,11 +483,14 @@ def log_pdf(x): # log logistic pdf
462483
463484 # Deal with edge cases
464485 log_probs = tf .where (y_true > 0.999 , log_one_minus_cdf_min , log_probs )
465- log_probs = tf .where (y_true < 0.999 , log_cdf_plus , log_probs )
486+ log_probs = tf .where (y_true < - 0.999 , log_cdf_plus , log_probs )
466487
467488 log_probs = tf .reduce_sum (log_probs , axis = 3 ) # whole pixel prob per component
468489 log_probs += tf .nn .log_softmax (pi ) # multiply by mixture components
469490 log_probs = tf .math .reduce_logsumexp (log_probs , axis = - 1 ) # add components probs
470491 log_probs = tf .reduce_sum (log_probs , axis = [1 , 2 ])
471492
472- return - log_probs
493+ # Convert to bits per dim
494+ bits_per_dim = - log_probs / num_pixels / tf .math .log (2. )
495+
496+ return bits_per_dim