Tensorflow WARNING: tensorflow:Gradients do not exist for variables #47469

@JCL823

Description


I am trying to implement a VQ-VAE using TensorFlow 2.x.
I borrowed the VQ-VAE training code from this GitHub repository, which is written in TF 2.x, and I'm working on merging the Keras PixelCNN training & sampling code from this tutorial into the VQ-VAE TF 2.x code.

Training the VQ-VAE seems fine. But when I take the quantized vectors from the trained VQ-VAE and use them to train the PixelCNN, I always get this warning:

WARNING:tensorflow:Gradients do not exist for variables
['conv2d_block/batch_normalization/gamma:0',
'conv2d_block/batch_normalization/beta:0',
'conv2d_block/conv2d/kernel:0', 'conv2d_block/conv2d/bias:0',
'conv2d_block_1/batch_normalization_1/gamma:0',
'conv2d_block_1/batch_normalization_1/beta:0',
'conv2d_block_1/conv2d_1/kernel:0', 'conv2d_block_1/conv2d_1/bias:0',
'conv2d_block_2/batch_normalization_2/gamma:0',
'conv2d_block_2/batch_normalization_2/beta:0',
'conv2d_block_2/conv2d_2/kernel:0', 'conv2d_block_2/conv2d_2/bias:0',
'conv2d_block_3/batch_normalization_3/gamma:0',
'conv2d_block_3/batch_normalization_3/beta:0',
'conv2d_block_3/conv2d_3/kernel:0', 'conv2d_block_3/conv2d_3/bias:0',
'conv2d_block_4/batch_normalization_4/gamma:0',
'conv2d_block_4/batch_normalization_4/beta:0',
'conv2d_block_4/conv2d_4/kernel:0', 'conv2d_block_4/conv2d_4/bias:0',
'conv2d_block_5/batch_normalization_5/gamma:0',
'conv2d_block_5/batch_normalization_5/beta:0',
'conv2d_block_5/conv2d_5/kernel:0', 'conv2d_block_5/conv2d_5/bias:0',
'conv2d_block_6/batch_normalization_6/gamma:0',
'conv2d_block_6/batch_normalization_6/beta:0',
'conv2d_block_6/conv2d_6/kernel:0', 'conv2d_block_6/conv2d_6/bias:0',
'conv2d_block_7/batch_normalization_7/gamma:0',
'conv2d_block_7/batch_normalization_7/beta:0',
'conv2d_block_7/conv2d_7/kernel:0', 'conv2d_block_7/conv2d_7/bias:0',
'conv2d_block_8/batch_normalization_8/gamma:0',
'conv2d_block_8/batch_normalization_8/beta:0',
'conv2d_block_8/conv2d_8/kernel:0', 'conv2d_block_8/conv2d_8/bias:0',
'conv2d_block_9/batch_normalization_9/gamma:0',
'conv2d_block_9/batch_normalization_9/beta:0',
'conv2d_block_9/conv2d_9/kernel:0', 'conv2d_block_9/conv2d_9/bias:0',
'conv2d_block_10/instance_normalization/scale:0',
'conv2d_block_10/instance_normalization/offset:0',
'conv2d_block_10/conv2d_10/kernel:0',
'conv2d_block_10/conv2d_10/bias:0',
'conv2d_block_11/instance_normalization_1/scale:0',
'conv2d_block_11/instance_normalization_1/offset:0',
'conv2d_block_11/conv2d_11/kernel:0',
'conv2d_block_11/conv2d_11/bias:0',
'conv2d_block_12/instance_normalization_2/scale:0',
'conv2d_block_12/instance_normalization_2/offset:0',
'conv2d_block_12/conv2d_12/kernel:0',
'conv2d_block_12/conv2d_12/bias:0',
'conv2d_block_13/instance_normalization_3/scale:0',
'conv2d_block_13/instance_normalization_3/offset:0',
'conv2d_block_13/conv2d_13/kernel:0',
'conv2d_block_13/conv2d_13/bias:0',
'conv2d_block_14/instance_normalization_4/scale:0',
'conv2d_block_14/instance_normalization_4/offset:0',
'conv2d_block_14/conv2d_14/kernel:0',
'conv2d_block_14/conv2d_14/bias:0',
'conv2d_block_15/instance_normalization_5/scale:0',
'conv2d_block_15/instance_normalization_5/offset:0',
'conv2d_block_15/conv2d_15/kernel:0',
'conv2d_block_15/conv2d_15/bias:0',
'conv2d_block_16/instance_normalization_6/scale:0',
'conv2d_block_16/instance_normalization_6/offset:0',
'conv2d_block_16/conv2d_16/kernel:0',
'conv2d_block_16/conv2d_16/bias:0',
'conv2d_block_17/instance_normalization_7/scale:0',
'conv2d_block_17/instance_normalization_7/offset:0',
'conv2d_block_17/conv2d_17/kernel:0',
'conv2d_block_17/conv2d_17/bias:0',
'conv2d_block_18/instance_normalization_8/scale:0',
'conv2d_block_18/instance_normalization_8/offset:0',
'conv2d_block_18/conv2d_18/kernel:0',
'conv2d_block_18/conv2d_18/bias:0',
'conv2d_block_19/conv2d_19/kernel:0',
'conv2d_block_19/conv2d_19/bias:0', 'conv2d_20/kernel:0',
'conv2d_20/bias:0', 'v_masked_conv_1/W_v:0', 'h_masked_conv_1/W_h:0',
'v_masked_conv_2/W_v:0', 'h_masked_conv_2/W_h:0',
'v_masked_conv_3/W_v:0', 'h_masked_conv_3/W_h:0',
'v_masked_conv_4/W_v:0', 'h_masked_conv_4/W_h:0',
'v_masked_conv_5/W_v:0', 'h_masked_conv_5/W_h:0',
'v_masked_conv_6/W_v:0', 'h_masked_conv_6/W_h:0',
'v_masked_conv_7/W_v:0', 'h_masked_conv_7/W_h:0',
'v_masked_conv_8/W_v:0', 'h_masked_conv_8/W_h:0',
'v_masked_conv_9/W_v:0', 'h_masked_conv_9/W_h:0',
'v_masked_conv_10/W_v:0', 'h_masked_conv_10/W_h:0',
'v_masked_conv_11/W_v:0', 'h_masked_conv_11/W_h:0',
'v_masked_conv_12/W_v:0', 'h_masked_conv_12/W_h:0'] when minimizing
the loss.
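(For reference while debugging: a minimal standalone way to see which variables receive gradients, not tied to my model, is to inspect the output of `tf.GradientTape.gradient`, where disconnected variables come back as `None`. The names below are purely illustrative.)

```python
import tensorflow as tf

# Tiny example: one trainable Dense layer plus a variable that the
# forward pass never touches, so its gradient comes back as None.
dense = tf.keras.layers.Dense(4)
unused = tf.Variable(tf.zeros([3]), name="unused")

x = tf.random.normal([2, 8])
with tf.GradientTape() as tape:
    y = dense(x)
    loss = tf.reduce_mean(tf.square(y))

variables = dense.trainable_variables + [unused]
grads = tape.gradient(loss, variables)
missing = [v.name for v, g in zip(variables, grads) if g is None]
print(missing)  # only the disconnected variable shows up here
```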

The following is what I have for my codes.

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()  
    x_train = (x_train[:200]/ 255.).astype("float32")
    x_test = (x_test[:200] / 255.).astype("float32")
    
    # train the VQ-VAE
    model = Autoencoder(config)
    model.fit(x_train, config)
    _, vq_return = model.call(x_train, training=False)
    quantize = vq_return['quantize']

where the fit function is one I newly defined for the Autoencoder model (which should actually be the VQ-VAE model); the rest is the same as in the repository.

For the PixelCNN part, I modified the above Keras tutorial to turn the Lambda layers into layer classes:

class SamplingLayer(tf.keras.layers.Layer):
    def __init__(self, model, **kwargs):
        super(SamplingLayer, self).__init__(**kwargs)  # forward kwargs so the layer name is kept
        self.model = model
    def call(self, encoding_indices, training=False):
        vq = self.model.vq
        return vq.quantize(encoding_indices)

class CodesSampler(tf.keras.Model):
    def __init__(self, **kwargs):
        super(CodesSampler, self).__init__(**kwargs)

    def call(self, model, size, training=False):
        sampling_layer = SamplingLayer(model)
        indices = tf.keras.layers.Input(shape=(size, size), name='codes_sampler_inputs', dtype='int32')
        z_q = sampling_layer(indices)
        codes_sampler = tf.keras.Model(inputs=indices, outputs=z_q, name="codes_sampler")
        return codes_sampler
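As a minimal sanity check of this wrap-a-layer-in-a-sub-model pattern (standalone sketch: an Embedding lookup stands in for vq.quantize, and the sizes are illustrative, not the ones from my model):

```python
import tensorflow as tf

SIZE, K, D = 4, 8, 16  # grid size, codebook size, embedding dim (illustrative)

# Stand-in for vq.quantize: a plain embedding lookup from indices to vectors.
embedding = tf.keras.layers.Embedding(K, D)
indices = tf.keras.layers.Input(shape=(SIZE, SIZE), dtype="int32")
z_q = embedding(indices)
codes_sampler = tf.keras.Model(indices, z_q, name="codes_sampler")

# Integer index grid in, dense code vectors out.
out = codes_sampler(tf.zeros([2, SIZE, SIZE], dtype=tf.int32))
print(out.shape)  # (2, 4, 4, 16)
```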

and the rest of the PixelCNN model is basically the same as in the tutorial:

'''Learning a prior over the latent space'''
class GateLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(GateLayer, self).__init__(**kwargs)
    def call(self, inputs):
        """Gated activations"""
        x, y = tf.split(inputs, 2, axis=-1)
        return Kb.tanh(x) * Kb.sigmoid(y)


class MaskedConv2D(tf.keras.layers.Layer):
    """Masked convolution"""
    def __init__(self, kernel_size, out_dim, direction, mode, **kwargs):
        self.direction = direction     # Horizontal or vertical
        self.mode = mode               # Mask type "a" or "b"
        self.kernel_size = kernel_size
        self.out_dim = out_dim
        super(MaskedConv2D, self).__init__(**kwargs)
    
    def build(self, input_shape):   
        filter_mid_y = self.kernel_size[0] // 2
        filter_mid_x = self.kernel_size[1] // 2        
        in_dim = int(input_shape[-1])
        w_shape = [self.kernel_size[0], self.kernel_size[1], in_dim, self.out_dim]
        mask_filter = np.ones(w_shape, dtype=np.float32)
        # Build the mask
        if self.direction == "h":
            mask_filter[filter_mid_y + 1:, :, :, :] = 0.
            mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.
        elif self.direction == "v":
            if self.mode == 'a':
                mask_filter[filter_mid_y:, :, :, :] = 0.
            elif self.mode == 'b':
                mask_filter[filter_mid_y+1:, :, :, :] = 0.0
        if self.mode == 'a':
            mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0
        # Create convolution layer parameters with masked kernel
        self.W = mask_filter * self.add_weight("W_{}".format(self.direction), w_shape, trainable=True)
        self.b = self.add_weight("v_b", [self.out_dim,], trainable=True)
    
    def call(self, inputs):
        return tf.keras.backend.conv2d(inputs, self.W, strides=(1, 1)) + self.b
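For reference, the mask construction in build() can be checked in isolation with plain NumPy (this is just an inspection helper mirroring the logic above on a 2-D kernel, not part of the model):

```python
import numpy as np

def build_mask(kernel_size, direction, mode):
    """Mirrors the MaskedConv2D mask construction on a 2-D kernel for inspection."""
    filter_mid_y = kernel_size[0] // 2
    filter_mid_x = kernel_size[1] // 2
    mask = np.ones(kernel_size, dtype=np.float32)
    if direction == "h":
        mask[filter_mid_y + 1:, :] = 0.0
        mask[filter_mid_y, filter_mid_x + 1:] = 0.0
    elif direction == "v":
        if mode == "a":
            mask[filter_mid_y:, :] = 0.0
        elif mode == "b":
            mask[filter_mid_y + 1:, :] = 0.0
    if mode == "a":
        mask[filter_mid_y, filter_mid_x] = 0.0
    return mask

# Horizontal stack, mask 'a': the centre pixel and everything to its right are excluded.
print(build_mask((1, 3), "h", "a"))  # [[1. 0. 0.]]
```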

    
def gated_masked_conv2d(v_stack_in, h_stack_in, out_dim, kernel, mask='b', residual=True, i=0):
    """Basic Gated-PixelCNN block. 
       This is an improvement over PixelRNN to avoid "blind spots", i.e. pixels missingt from the
       field of view. It works by having two parallel stacks, for the vertical and horizontal direction, 
       each being masked  to only see the appropriate context pixels.
    """
    kernel_size = (kernel // 2 + 1, kernel)
    padding = (kernel // 2, kernel // 2)
    v_gate = GateLayer(name="v_gate_{}".format(i))
    v_stack = tf.keras.layers.ZeroPadding2D(padding=padding, name="v_pad_{}".format(i))(v_stack_in)
    v_stack = MaskedConv2D(kernel_size, out_dim * 2, "v", mask, name="v_masked_conv_{}".format(i))(v_stack)
    v_stack = v_stack[:, :int(v_stack_in.get_shape()[-3]), :, :]
    v_stack_out = v_gate(v_stack)
    
    kernel_size = (1, kernel // 2 + 1)
    padding = (0, kernel // 2)
    h_gate = GateLayer(name="h_gate_{}".format(i))
    h_stack = tf.keras.layers.ZeroPadding2D(padding=padding, name="h_pad_{}".format(i))(h_stack_in)
    h_stack = MaskedConv2D(kernel_size, out_dim * 2, "h", mask, name="h_masked_conv_{}".format(i))(h_stack)
    h_stack = h_stack[:, :, :int(h_stack_in.get_shape()[-2]), :]
    h_stack_1 = tf.keras.layers.Conv2D(filters=out_dim * 2, kernel_size=1, strides=(1, 1), name="v_to_h_{}".format(i))(v_stack)
    h_stack_out = h_gate(h_stack + h_stack_1)
    
    h_stack_out = tf.keras.layers.Conv2D(filters=out_dim, kernel_size=1, strides=(1, 1), name="res_conv_{}".format(i))(h_stack_out)
    if residual:
        h_stack_out += h_stack_in
    return v_stack_out, h_stack_out


def build_pixelcnn(codes_sampler, k, size, num_layers, num_feature_maps=32):
    pixelcnn_prior_inputs = tf.keras.layers.Input(shape=(size, size), name='pixelcnn_prior_inputs', dtype=tf.int32)
    z_q = codes_sampler(pixelcnn_prior_inputs, size) # maps indices (z_train in the implementation) to the actual codebook
    
    v_stack_in, h_stack_in = z_q, z_q
    for i in range(num_layers):
        mask = 'b' if i > 0 else 'a'
        kernel_size = 3 if i > 0 else 7
        residual = True if i > 0 else False
        v_stack_in, h_stack_in = gated_masked_conv2d(v_stack_in, h_stack_in, num_feature_maps,
                                                     kernel=kernel_size, residual=residual, i=i + 1)

    fc1 = tf.keras.layers.Conv2D(filters=num_feature_maps, kernel_size=1, name="fc1")(h_stack_in)
    fc2 = tf.keras.layers.Conv2D(filters=k, kernel_size=1, name="fc2")(fc1) 
    # outputs logits for probabilities of codebook indices for each cell

    pixelcnn_prior = tf.keras.Model(inputs=pixelcnn_prior_inputs, outputs=fc2, name='pixelcnn-prior')

    # Distribution to sample from the pixelcnn
    dist = tfp.distributions.Categorical(logits=fc2)
    sampled = dist.sample()
    prior_sampler = tf.keras.Model(inputs=pixelcnn_prior_inputs, outputs=sampled, name='pixelcnn-prior-sampler')
    return pixelcnn_prior, prior_sampler


# Train the PixelCNN and monitor prediction accuracy
def accuracy(y_true, y_pred):
    size = int(y_pred.get_shape()[-2])
    k = int(y_pred.get_shape()[-1])
    y_true = tf.reshape(y_true, (-1, size * size))
    y_pred = tf.reshape(y_pred, (-1, size * size, k))
    return Kb.cast(Kb.equal(y_true, Kb.cast(Kb.argmax(y_pred, axis=-1), Kb.floatx())), Kb.floatx())
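As a standalone sanity check of this metric with toy tensors (assuming Kb is tf.keras.backend, and redefining the function here so the sketch is self-contained):

```python
import tensorflow as tf
Kb = tf.keras.backend

def accuracy(y_true, y_pred):
    size = int(y_pred.get_shape()[-2])
    k = int(y_pred.get_shape()[-1])
    y_true = tf.reshape(y_true, (-1, size * size))
    y_pred = tf.reshape(y_pred, (-1, size * size, k))
    return Kb.cast(Kb.equal(y_true, Kb.cast(Kb.argmax(y_pred, axis=-1), Kb.floatx())), Kb.floatx())

# 1 sample, 2x2 grid, 3 codebook entries; the logits pick index 2 everywhere,
# while the targets are index 2 in three cells and index 0 in one.
y_pred = tf.one_hot(tf.fill([1, 2, 2], 2), depth=3)
y_true = tf.constant([[[2., 2.], [2., 0.]]])
print(accuracy(y_true, y_pred).numpy())  # [[1. 1. 1. 0.]]
```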

The last step is training the PixelCNN using the encoding indices from the trained VQ-VAE:

z_train = model.call(x_train, training=False)[1]['encoding_indices']
pixelcnn_prior, prior_sampler = build_pixelcnn(codes_sampler, NUM_LATENT_K, SIZE, 
                                               PIXELCNN_NUM_BLOCKS, PIXELCNN_NUM_FEATURE_MAPS)
pixelcnn_prior.summary()
pixelcnn_prior.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[accuracy],
                       optimizer=tf.keras.optimizers.Adam(PIXELCNN_LEARNING_RATE))
prior_history = pixelcnn_prior.fit(z_train, z_train, epochs=PIXELCNN_NUM_EPOCHS, 
                                   batch_size=PIXELCNN_BATCH_SIZE, verbose=1)

It's also weird that when I print the summary tables of the Keras tutorial and of my modified TF 2.x code, they differ in two places: the codes_sampler parameter count and the Non-trainable params: 8,064.

The top is the summary of the Keras tutorial with most of the common parts omitted:

Model: "pixelcnn-prior"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
pixelcnn_prior_inputs (InputLay [(None, 8, 8)]       0                                            
__________________________________________________________________________________________________
codes_sampler (Model)           (None, 8, 8, 256)    0           pixelcnn_prior_inputs[0][0]      
__________________________________________________________________________________________________
v_pad_1 (ZeroPadding2D)         (None, 14, 14, 256)  0           codes_sampler[1][0]              
__________________________________________________________________________________________________
h_pad_1 (ZeroPadding2D)         (None, 8, 14, 256)   0           codes_sampler[1][0]              
__________________________________________________________________________________________________
v_masked_conv_1 (MaskedConv2D)  (None, 11, 8, 64)    458816      v_pad_1[0][0]                    
__________________________________________________________________________________________________
h_masked_conv_1 (MaskedConv2D)  (None, 8, 11, 64)    65600       h_pad_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_strided_slice_4 (Te [(None, 8, 8, 64)]   0           v_masked_conv_1[0][0]            
__________________________________________________________________________________________________
tf_op_layer_strided_slice_5 (Te [(None, 8, 8, 64)]   0           h_masked_conv_1[0][0]            
__________________________________________________________________________________________________
v_to_h_1 (Conv2D)               (None, 8, 8, 64)     4160        tf_op_layer_strided_slice_4[0][0]
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, 8, 8, 64)]   0           tf_op_layer_strided_slice_5[0][0]
                                                                 v_to_h_1[0][0]                   
__________________________________________________________________________________________________
h_gate_1 (Lambda)               (None, 8, 8, 32)     0           tf_op_layer_add[0][0]            
__________________________________________________________________________________________________
v_gate_1 (Lambda)               (None, 8, 8, 32)     0           tf_op_layer_strided_slice_4[0][0]

and

==================================================================================================
Total params: 786,592
Trainable params: 786,592
Non-trainable params: 0

For my modified TF2.X code,

Model: "pixelcnn-prior"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
pixelcnn_prior_inputs (InputLay [(None, 4, 4)]       0                                            
__________________________________________________________________________________________________
codes_sampler (Functional)      (None, 4, 4, 256)    21824387    pixelcnn_prior_inputs[0][0]      
__________________________________________________________________________________________________
v_pad_1 (ZeroPadding2D)         (None, 10, 10, 256)  0           codes_sampler[0][0]              
__________________________________________________________________________________________________
h_pad_1 (ZeroPadding2D)         (None, 4, 10, 256)   0           codes_sampler[0][0]              
__________________________________________________________________________________________________
v_masked_conv_1 (MaskedConv2D)  (None, 7, 4, 64)     458816      v_pad_1[0][0]                    
__________________________________________________________________________________________________
h_masked_conv_1 (MaskedConv2D)  (None, 4, 7, 64)     65600       h_pad_1[0][0]                    
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 4, 4, 64)     0           v_masked_conv_1[0][0]            
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, 4, 4, 64)     0           h_masked_conv_1[0][0]            
__________________________________________________________________________________________________
v_to_h_1 (Conv2D)               (None, 4, 4, 64)     4160        tf.__operators__.getitem[0][0]   
__________________________________________________________________________________________________
tf.__operators__.add (TFOpLambd (None, 4, 4, 64)     0           tf.__operators__.getitem_1[0][0] 
                                                                 v_to_h_1[0][0]                   
__________________________________________________________________________________________________
gate_layer_1 (GateLayer)        (None, 4, 4, 32)     0           tf.__operators__.add[0][0]       
__________________________________________________________________________________________________
gate_layer (GateLayer)          (None, 4, 4, 32)     0           tf.__operators__.getitem[0][0]   

and

==================================================================================================
Total params: 22,610,979
Trainable params: 22,602,915
Non-trainable params: 8,064

Could you please let me know where I went wrong? I've tried different things but cannot sort it out...

Metadata

Labels: comp:ops (OPs related issues), stale (to be closed automatically if no activity), stat:awaiting response (Status - Awaiting response from author), type:support (Support issues)