Status: Closed
Labels: bug
Description
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colab
TensorFlow version and how it was installed (source or binary): tf-nightly 2.2.0-dev20200408 (pip install) and the 2.1 stable release
TensorFlow-Addons version and how it was installed (source or binary):
Python version:
Is GPU used? (yes/no): TPU
Describe the bug
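Training a BERT SQuAD model with the optimizer returned by `optimization.create_optimizer` runs into what appears to be the same `apply_gradients()` missing-argument problem reported for the AdamW optimizer in #1267 (see Other info / logs).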
Code to reproduce the issue
import time

import tensorflow as tf
from official.nlp import optimization  # TF Model Garden BERT optimizer utilities

# INIT_LR, NB_BATCHES_TRAIN, WARMUP_STEPS, bert_squad and train_dataset_light
# are defined earlier in the notebook (not shown here).

optimizer = optimization.create_optimizer(
    init_lr=INIT_LR,
    num_train_steps=NB_BATCHES_TRAIN,  # steps per epoch
    num_warmup_steps=WARMUP_STEPS)


def squad_loss_fn(labels, model_outputs):
    # Average the cross-entropy of the predicted start and end positions.
    start_positions = labels['start_positions']
    end_positions = labels['end_positions']
    start_logits, end_logits = model_outputs
    start_loss = tf.keras.backend.sparse_categorical_crossentropy(
        start_positions, start_logits, from_logits=True)
    end_loss = tf.keras.backend.sparse_categorical_crossentropy(
        end_positions, end_logits, from_logits=True)
    total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
    return total_loss


train_loss = tf.keras.metrics.Mean(name="train_loss")

bert_squad.compile(optimizer, squad_loss_fn)

# Training loop
NB_EPOCHS = 3
for epoch in range(NB_EPOCHS):
    print("Start of epoch {}".format(epoch + 1))
    start = time.time()
    train_loss.reset_states()

    for (batch, (inputs, targets)) in enumerate(train_dataset_light):
        with tf.GradientTape() as tape:
            model_outputs = bert_squad(inputs)
            loss = squad_loss_fn(targets, model_outputs)
        grads = tape.gradient(loss, bert_squad.trainable_variables)
        optimizer.apply_gradients(zip(grads, bert_squad.trainable_variables))
        # optimizer.apply_gradients(zip(grads, bert_squad.trainable_variables), name=None, all_reduce_sum_gradients=True)
        train_loss(loss)
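`squad_loss_fn` averages two sparse categorical cross-entropies computed from logits: one over start positions, one over end positions. As a framework-free sanity check of that arithmetic (a plain-Python sketch, not the tf.keras implementation; the function names are hypothetical), each per-example term is `logsumexp(logits) - logits[label]`:

```python
import math


def sparse_ce_from_logits(label, logits):
    """Cross-entropy of one integer label against unnormalized logits.

    Equals -log(softmax(logits)[label]) = logsumexp(logits) - logits[label].
    """
    m = max(logits)  # subtract the max before exponentiating, for stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def squad_loss(start_positions, end_positions, start_logits, end_logits):
    """Mean start loss and mean end loss over the batch, averaged together."""
    start = [sparse_ce_from_logits(p, l)
             for p, l in zip(start_positions, start_logits)]
    end = [sparse_ce_from_logits(p, l)
           for p, l in zip(end_positions, end_logits)]
    return (sum(start) / len(start) + sum(end) / len(end)) / 2
```

With uniform logits over 4 token positions, both terms equal log 4, so the loss is log 4 (about 1.386); confident, correct logits drive it toward 0.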
Other info / logs
Same issue here: Missing argument in apply_gradients() in AdamW optimizer #1267
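One plausible mechanism, stated as an assumption rather than taken from the report: a subclass that overrides `apply_gradients` with a fixed signature rejects a keyword that a newer caller passes (here `all_reduce_sum_gradients`, the argument in the commented-out call above). The classes below are hypothetical stand-ins, not the real Keras or Model Garden optimizers; they only show why a fixed signature raises `TypeError` and how forwarding `**kwargs` to the base class avoids it.

```python
# Hypothetical stand-ins; these are NOT the real Keras or tf-models classes.


class BaseOptimizer:
    """Mimics a base class whose apply_gradients gained a new keyword."""

    def apply_gradients(self, grads_and_vars, name=None,
                        all_reduce_sum_gradients=True):
        # Return what was received so the call path can be inspected.
        return {"n": len(list(grads_and_vars)), "name": name,
                "all_reduce_sum_gradients": all_reduce_sum_gradients}


class FixedSignatureAdamW(BaseOptimizer):
    """Override written before the new keyword existed."""

    def apply_gradients(self, grads_and_vars, name=None):
        return super().apply_gradients(grads_and_vars, name=name)


class ForwardingAdamW(BaseOptimizer):
    """Override that forwards any extra keyword arguments to the base class."""

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        return super().apply_gradients(grads_and_vars, name=name, **kwargs)


def call_like_newer_caller(opt):
    """Call apply_gradients the way a newer caller might, with the new keyword."""
    pairs = [(0.1, "w"), (0.2, "b")]
    return opt.apply_gradients(pairs, name=None, all_reduce_sum_gradients=True)


if __name__ == "__main__":
    try:
        call_like_newer_caller(FixedSignatureAdamW())
    except TypeError as err:
        print("FixedSignatureAdamW:", err)
    print("ForwardingAdamW:", call_like_newer_caller(ForwardingAdamW()))
```

Forwarding `**kwargs` keeps the override compatible with keywords added to the base signature later, at the cost of no longer rejecting misspelled arguments locally.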