ValueError: No gradients provided for any variable: ['embedding_2/embeddings:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/kernel:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/recurrent_kernel:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/bias:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/kernel:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/recurrent_kernel:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/bias:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/kernel:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/recurrent_kernel:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/bias:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/kernel:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/recurrent_kernel:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0']. #43258

Closed
RavitejaBadugu opened this issue Sep 16, 2020 · 7 comments
Assignees
Labels
comp:keras Keras related issues type:others issues not falling in bug, performance, support, build and install or feature

Comments

@RavitejaBadugu

RavitejaBadugu commented Sep 16, 2020

I am training a model in Colab. The model I'm trying to build is a Siamese network for text similarity.
The loss function I used finds the negative samples internally.
Explanation of the loss function:
Kindly watch these two short videos, where the lecturer explains the loss function I used:

https://www.coursera.org/lecture/sequence-models-in-nlp/computing-the-cost-i-T4Ylj
https://www.coursera.org/lecture/sequence-models-in-nlp/computing-the-cost-ii-qXOjN

The dataset can be found here:
https://www.kaggle.com/c/quora-question-pairs/data

My Colab notebook:
https://colab.research.google.com/drive/1NCUxSS9fiuLpPd2hOSxKv5QKH2IUuhX2?usp=sharing

import tensorflow as tf
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Input
from tensorflow.keras.models import Model

sub_model=tf.keras.models.Sequential([Embedding(vocab_size,300,input_length=79),
                                      Bidirectional(LSTM(79,return_sequences=True)),
                                      Bidirectional(LSTM(79,return_sequences=True)),
                                      tf.keras.layers.GlobalAveragePooling1D(),
                                      tf.keras.layers.Dense(units=158)])
ins1=Input((79,),name='input1')
ins2=Input((79,),name='input2')
x1=sub_model(ins1)
x2=sub_model(ins2)
## Lambda (not Layer) is needed so the normalization function is actually applied
norm1=tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x,axis=1))(x1)
norm2=tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x,axis=1))(x2)

model=Model([ins1,ins2],[norm1,norm2])
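A quick shape check (my addition, assuming vocab_size is defined and the sequences are padded to length 79): one call on a pair of inputs should return two (batch, 158) normalized embeddings.

import numpy as np

dummy=np.zeros((4,79),dtype='int32')   # four fake padded sequences
o1,o2=model([dummy,dummy])
print(o1.shape,o2.shape)               # (4, 158) (4, 158)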

The loss function used is:


def get_tripletloss(y_pred1,y_pred2):
    y_pred=tf.matmul(y_pred1,y_pred2,transpose_b=True) ## similarity matrix of shape (batch,batch)
    batch=y_pred.get_shape().as_list()[0] ## getting batch_size
    alpha_matrix=tf.cast(tf.reshape(tf.repeat(0.2,repeats=batch),shape=(batch,1)),dtype=tf.float32) ## matrix of 0.2's (the margin alpha)
    diag_part=tf.cast(tf.reshape(tf.linalg.diag_part(y_pred),shape=(batch,1)),dtype=tf.float32) ## anchor-positive similarities (the diagonal)
    diagonal_matrix=tf.cast(tf.linalg.diag(tf.linalg.diag_part(y_pred)),dtype=tf.float32) ## the diagonal as a full matrix
    sim_an=tf.reshape(tf.reduce_mean(tf.cast(y_pred,dtype=tf.float32)-diagonal_matrix,axis=1),shape=(batch,1)) ## row-wise mean with the diagonal zeroed out (mean negative similarity)
    sim_an=tf.cast(sim_an,dtype=tf.float32)-diag_part+alpha_matrix ## sim_an-sim_ap+alpha
    sim_an=tf.maximum(sim_an,tf.cast(tf.zeros((batch,1)),dtype=tf.float32)) ## max(loss,0)
    loss1=tf.keras.backend.mean(sim_an) ## final loss1
    ##########
    y_pred=tf.where(y_pred<diag_part,y_pred,tf.cast(0.000001,dtype=tf.float32)) ## keep values smaller than the diagonal; replace the diagonal and any larger off-diagonal values with a small number
    sim_an2=tf.reshape(tf.reduce_max(y_pred,axis=1),shape=(batch,1)) ## max remaining value (closest negative)
    loss2=tf.cast(sim_an2,dtype=tf.float32)-diag_part+alpha_matrix ## sim_an-sim_ap+alpha
    loss2=tf.maximum(loss2,tf.cast(tf.zeros((batch,1)),dtype=tf.float32)) ## max(loss2,0)
    loss2=tf.keras.backend.mean(loss2)
    loss=loss1+loss2
    return loss
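A minimal sanity check (my addition, not from the original post) that the loss runs on L2-normalized embeddings of the shape the model produces:

batch,dim=29,158
emb1=tf.math.l2_normalize(tf.random.normal((batch,dim)),axis=1)
emb2=tf.math.l2_normalize(tf.random.normal((batch,dim)),axis=1)
print(get_tripletloss(emb1,emb2))   # a scalar loss tensor, e.g. tf.Tensor(0.2..., shape=(), dtype=float32)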

And the training loop is:

dataset=tf.data.Dataset.from_tensor_slices((seq1,seq2))
dataset=dataset.shuffle(149263)
dataset=dataset.batch(29,drop_remainder=True)
dataset=dataset.prefetch(tf.data.experimental.AUTOTUNE)

@tf.function
def get_grads(v1,v2):
    with tf.GradientTape() as tape:
        loss=get_tripletloss(v1,v2)
        grads=tape.gradient(loss,model.trainable_variables)
    return loss,grads
loss_per_batch=[]
loss_per_epoch=[]
optimizer=tf.keras.optimizers.Adam()
for j in range(2):
    training_batches=seq1.shape[0]//29
    for i in range(training_batches):
        data=next(iter(dataset))
        v1,v2=model(data)
        loss,grads=get_grads(v1,v2)
        loss_per_batch.append(loss)
        optimizer.apply_gradients(zip(grads,model.trainable_variables))
    loss_per_epoch.append(np.mean(loss_per_batch))
    loss_per_batch=[]
    print(f"finsihed {j+1} epoch got loss of {loss_per_epoch[j]}")

The error is:

ValueError                                Traceback (most recent call last)
<ipython-input-98-24aaeac0e75b> in <module>()
      9         loss,grads=get_grads(v1,v2)
     10         loss_per_batch.append(loss)
---> 11         optimizer.apply_gradients(zip(grads,model.trainable_variables))
     12     loss_per_epoch.append(np.mean(loss_per_batch))
     13     loss_per_batch=[]

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in apply_gradients(self, grads_and_vars, name, experimental_aggregate_gradients)
    511       ValueError: If none of the variables have gradients.
    512     """
--> 513     grads_and_vars = _filter_grads(grads_and_vars)
    514     var_list = [v for (_, v) in grads_and_vars]
    515 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in _filter_grads(grads_and_vars)
   1269   if not filtered:
   1270     raise ValueError("No gradients provided for any variable: %s." %
-> 1271                      ([v.name for _, v in grads_and_vars],))
   1272   if vars_with_empty_grads:
   1273     logging.warning(

ValueError: No gradients provided for any variable: ['embedding_2/embeddings:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/kernel:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/recurrent_kernel:0', 'bidirectional_4/forward_lstm_4/lstm_cell_13/bias:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/kernel:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/recurrent_kernel:0', 'bidirectional_4/backward_lstm_4/lstm_cell_14/bias:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/kernel:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/recurrent_kernel:0', 'bidirectional_5/forward_lstm_5/lstm_cell_16/bias:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/kernel:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/recurrent_kernel:0', 'bidirectional_5/backward_lstm_5/lstm_cell_17/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0'].
@RavitejaBadugu RavitejaBadugu added the type:others issues not falling in bug, performance, support, build and install or feature label Sep 16, 2020
@Saduf2019
Contributor

@RavitejaBadugu
The Colab notebook you shared is empty; please update it with the complete code and the error.
Please also refer to this issue with the same error: #42038.

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Sep 16, 2020
@RavitejaBadugu
Author

RavitejaBadugu commented Sep 17, 2020

@Saduf2019
That's not a gist. Just copy the link and paste it into a new tab; it's not empty.

@Saduf2019 Saduf2019 added comp:keras Keras related issues and removed stat:awaiting response Status - Awaiting response from author labels Sep 18, 2020
@Saduf2019 Saduf2019 assigned gowthamkpr and unassigned Saduf2019 Sep 21, 2020
@RavitejaBadugu
Author

RavitejaBadugu commented Sep 24, 2020

How can we create a loss function that doesn't take y_true? The loss function I showed in the issue computes the loss without y_true (for details, check the links). In the TF documentation I found that a custom loss function needs to be defined as

def loss(y_true, y_pred, smoothing_parameter):

So the framework expects us to supply y_true, even though the logic in my loss function is correct without it. My doubt is how to define a loss without y_true; maybe that is what's causing the problem, but I'm not sure. Please clarify this, @gowthamkpr.
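For reference, one common pattern (my illustration, simplified to a single output; nothing in this thread confirms it is what was intended here) is to keep the (y_true, y_pred) signature Keras expects, ignore y_true inside the loss, and feed dummy labels to fit():

import numpy as np
import tensorflow as tf

def label_free_loss(y_true,y_pred):
    del y_true                                          # unused: negatives come from the batch itself
    sim=tf.matmul(y_pred,y_pred,transpose_b=True)       # (batch,batch) in-batch similarity
    pos=tf.linalg.diag_part(sim)[:,None]                # anchor-positive similarities
    return tf.reduce_mean(tf.nn.relu(sim-pos+0.2))      # hinge against all in-batch negatives

# model.compile(optimizer='adam',loss=label_free_loss)
# model.fit(x,np.zeros((len(x),1)),batch_size=29)       # dummy y_true is ignored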

@gowthamkpr gowthamkpr assigned ymodak and unassigned gowthamkpr Nov 2, 2020
@RavitejaBadugu
Author

I got the solution. I made a small mistake in my code.

@KrishPro

KrishPro commented Feb 7, 2021

@RavitejaBadugu

I got the solution. I made a small mistake in my code.

Can you tell me what that mistake was? I'm also getting the same error.

@yhetman

yhetman commented May 12, 2021

@RavitejaBadugu
I also have the same error. I would appreciate it if you could tell us how you solved it.

@RavitejaBadugu
Author

In a custom training loop, the forward pass model(data) needs to happen within the tape. When I called it before (outside) the tape, I got the error:

loss_per_batch=[]
loss_per_epoch=[]
optimizer=tf.keras.optimizers.Adam()
for j in range(2):
    training_batches=x_train.shape[0]//12
    for i in tqdm(range(training_batches)):
        data=next(iter(dataset))
        outputs=model(data) ## the forward pass is called outside the tape
        with tf.GradientTape() as tape:
          losses=loss(outputs.get('out1'),outputs.get('out2'),0.3,-2.0)
        grads=tape.gradient(losses,model.trainable_variables)
        loss_per_batch.append(losses)
        optimizer.apply_gradients(zip(grads,model.trainable_variables))
    loss_per_epoch.append(np.mean(loss_per_batch))
    loss_per_batch=[]
    print(f"finished {j+1} epoch, got a loss of {loss_per_epoch[j]}")

This fails with:

ValueError: No gradients provided for any variable: ['tf_roberta_model/roberta/encoder/layer_._0/attention/self/query/kernel:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/self/query/bias:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/self/key/kernel:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/self/key/bias:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/self/value/kernel:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/self/value/bias:0', 'tf_roberta_model/roberta/encoder/layer_._0/attention/o....

But when I call the model within the tape, the error goes away:

loss_per_batch=[]
loss_per_epoch=[]
optimizer=tf.keras.optimizers.Adam()
for j in range(2):
    training_batches=x_train.shape[0]//12
    for i in tqdm(range(training_batches)):
        data=next(iter(dataset))
        with tf.GradientTape() as tape:
          outputs=model(data) ## the forward pass is now inside the tape
          losses=loss(outputs.get('out1'),outputs.get('out2'),0.3,-2.0)
        grads=tape.gradient(losses,model.trainable_variables)
        loss_per_batch.append(losses)
        optimizer.apply_gradients(zip(grads,model.trainable_variables))
    loss_per_epoch.append(np.mean(loss_per_batch))
    loss_per_batch=[]
    print(f"finished {j+1} epoch, got a loss of {loss_per_epoch[j]}")
