@@ -261,6 +261,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
261261 List of images transformed by the predicted CDNA kernels.
262262 """
263263 batch_size = int (cdna_input .get_shape ()[0 ])
264+ height = int (prev_image .get_shape ()[1 ])
265+ width = int (prev_image .get_shape ()[2 ])
264266
265267 # Predict kernels using linear function of last hidden layer.
266268 cdna_kerns = slim .layers .fully_connected (
@@ -276,20 +278,22 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
276278 norm_factor = tf .reduce_sum (cdna_kerns , [1 , 2 , 3 ], keep_dims = True )
277279 cdna_kerns /= norm_factor
278280
279- cdna_kerns = tf .tile (cdna_kerns , [1 , 1 , 1 , color_channels , 1 ])
280- cdna_kerns = tf .split (axis = 0 , num_or_size_splits = batch_size , value = cdna_kerns )
281- prev_images = tf .split (axis = 0 , num_or_size_splits = batch_size , value = prev_image )
281+ # Treat the color channel dimension as the batch dimension since the same
282+ # transformation is applied to each color channel.
283+ # Treat the batch dimension as the channel dimension so that
284+ # depthwise_conv2d can apply a different transformation to each sample.
285+ cdna_kerns = tf .transpose (cdna_kerns , [1 , 2 , 0 , 4 , 3 ])
286+ cdna_kerns = tf .reshape (cdna_kerns , [DNA_KERN_SIZE , DNA_KERN_SIZE , batch_size , num_masks ])
287+ # Swap the batch and channel dimensions.
288+ prev_image = tf .transpose (prev_image , [3 , 1 , 2 , 0 ])
282289
283290 # Transform image.
284- transformed = []
285- for kernel , preimg in zip (cdna_kerns , prev_images ):
286- kernel = tf .squeeze (kernel )
287- if len (kernel .get_shape ()) == 3 :
288- kernel = tf .expand_dims (kernel , - 1 )
289- transformed .append (
290- tf .nn .depthwise_conv2d (preimg , kernel , [1 , 1 , 1 , 1 ], 'SAME' ))
291- transformed = tf .concat (axis = 0 , values = transformed )
292- transformed = tf .split (axis = 3 , num_or_size_splits = num_masks , value = transformed )
291+ transformed = tf .nn .depthwise_conv2d (prev_image , cdna_kerns , [1 , 1 , 1 , 1 ], 'SAME' )
292+
293+ # Transpose the dimensions to where they belong.
294+ transformed = tf .reshape (transformed , [color_channels , height , width , batch_size , num_masks ])
295+ transformed = tf .transpose (transformed , [3 , 1 , 2 , 0 , 4 ])
296+ transformed = tf .unstack (transformed , axis = - 1 )
293297 return transformed
294298
295299
0 commit comments