
Restoring from checkpoints is broken in TF 1.13.1 #27937

Closed
princessofpillows opened this issue Apr 18, 2019 · 8 comments

@princessofpillows

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: No
  • TensorFlow installed from (source or binary): Binary, pip3 install tensorflow-gpu
  • TensorFlow version (use command below): 1.13.1
  • Python version: 3.6
  • Bazel version (if compiling from source): No
  • GCC/Compiler version (if compiling from source): No
  • CUDA/cuDNN version: V10.0.130
  • GPU model and memory: GTX 1060m, 6GB

Describe the current behavior
I am unable to restore the weights of any of my tf.keras models, but ONLY when restoring into a newly initialized model. If I change the weights and then restore without reinitializing the model, it restores properly. Furthermore, the error raised in the failing case is SILENT: I only see it if I print the restore status myself.

Describe the expected behavior
The weights should restore without running into an error. If an error does occur, it should be logged without my having to print the restore status myself.

Code to reproduce the issue

import os
import tensorflow as tf
import numpy as np

# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img[None])
    loss = tf.losses.mean_squared_error(truth[None], logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save(checkpoint_path)

# Check if weights update
print("Are weights empty after training?", model.weights == [])

# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()

Other info / logs

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6ac9691f98>
Traceback (most recent call last):
  File "test.py", line 56, in <module>
    status.assert_consumed()
  File "/home/jpatts/Documents/alpha-doom/env/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1025, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}
@allenlavoie
Member

allenlavoie commented Apr 18, 2019

model hasn't created variables yet in this case:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print(model.variables)

Prints [] (or very soon will throw an exception). So restoring at that point defers the restoration; the checkpoint guide has an explanation. If you restore and then call model, assert_consumed will pass:

import numpy as np
import tensorflow as tf

model_1 = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])

model_1(np.random.uniform(0, 255, (1, 32, 32, 3)))
save_path = tf.train.Checkpoint(model=model_1).save("/tmp/tf_ckpts/")

model_2 = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])

restore_checkpoint = tf.train.Checkpoint(model=model_2)
status = restore_checkpoint.restore(save_path)
#status.assert_consumed()  # Fails! model_2.variables is empty
model_2(np.random.uniform(0, 255, (1, 32, 32, 3)))
status.assert_consumed()  # Passes

The model's output reflects the correctly restored values even though the variables didn't exist when the restore() call was made.

Is that clear? Happy to hear ideas for better documentation. I know it's a somewhat tricky API, but restore-on-create is a bit of a trilemma: we either require input shapes at Layer construction so variables can be created immediately, or we do deferred restoration, or we require symbolic construction of the computation first (the TF 1.x approach), which gives us enough information to create the variables. We decided against the first path since the shapes are annoying to have to specify, and with eager on by default the third path isn't available (although you can optionally pass an input_shape to the first Layer in Sequential and it will build everything right away).
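
A minimal sketch of that last option, assuming the same two-layer model as above: with input_shape on the first layer, Sequential builds its variables immediately, so a restore() at that point would not be deferred.

import tensorflow as tf

tf.enable_eager_execution()

# input_shape on the first layer builds the whole model up front,
# so model.variables is populated before any data is passed through it.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same", input_shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print(len(model.variables))  # 4 (two kernels, two biases) rather than []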

@princessofpillows
Author

princessofpillows commented Apr 18, 2019

Thanks for the quick response. I understand what you are saying, and I think this deserves a line or two of explanation in the TensorFlow Eager tutorial. Since the tutorial uses a plain variable, which doesn't need a shape before it can be built, the problem doesn't show up until you hit it in practice. It could definitely use more visibility.

Also, given that the weights aren't restored until the model receives input, I should compare weights after restoring and calling the model but before the gradient step, correct? At that point they should have taken on the restored values, so I can assert that they are equal.

@allenlavoie
Member

Thanks, yes, it sounds like the eager guide should have a blurb about deferred restoration and a reference to the checkpointing guide. The main reason it doesn't is presumably just the order in which they were written.

On checking that values are restored, yes that makes sense to me. Something like this:

import tensorflow as tf
import numpy as np

def recreate_model_and_checkpoint():
  model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
  ])
  return tf.train.Checkpoint(
      optimizer=tf.keras.optimizers.Adam(0.1),
      model=model)

def train_step(checkpoint):
  model = checkpoint.model
  optimizer = checkpoint.optimizer
  with tf.GradientTape() as tape:
    output = model(tf.ones([1, 32, 32, 3]))
    loss = tf.reduce_sum(output)
  variables = model.trainable_variables
  gradients = tape.gradient(loss, variables)
  before_train_step_weights = [v.numpy() for v in variables]
  optimizer.apply_gradients(zip(gradients, variables))
  return loss, before_train_step_weights

checkpoint_one = recreate_model_and_checkpoint()
# Just to create the variables so we have something to save
train_step(checkpoint_one)
save_path = checkpoint_one.save("/tmp/tf_ckpts/")
original_loss_1, original_variable_values_1 = train_step(checkpoint_one)
original_loss_2, original_variable_values_2 = train_step(checkpoint_one)

checkpoint_two = recreate_model_and_checkpoint()
status = checkpoint_two.restore(save_path)
new_loss_1, new_variable_values_1 = train_step(checkpoint_two)
status.assert_consumed()
new_loss_2, new_variable_values_2 = train_step(checkpoint_two)

np.testing.assert_allclose(new_loss_1.numpy(), original_loss_1.numpy())
np.testing.assert_allclose(new_loss_2.numpy(), original_loss_2.numpy())
assert len(original_variable_values_1) == len(new_variable_values_1)
for original_value, new_value in zip(original_variable_values_1, new_variable_values_1):
  np.testing.assert_allclose(original_value, new_value)
for original_value, new_value in zip(original_variable_values_2, new_variable_values_2):
  np.testing.assert_allclose(original_value, new_value)

You could also directly check the optimizer's slot variables rather than running two steps and checking both.
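
For example, something along these lines (a rough sketch, assuming the snippet above has already run; get_weights() on the tf.keras optimizer returns its state, including the Adam slots, as numpy arrays):

# After the snippet above, both optimizers have taken the same steps on the
# same data, so their full state (iteration count plus Adam's m/v slots)
# should match element-wise.
original_optimizer_state = checkpoint_one.optimizer.get_weights()
restored_optimizer_state = checkpoint_two.optimizer.get_weights()
assert len(original_optimizer_state) == len(restored_optimizer_state)
for original_value, restored_value in zip(original_optimizer_state,
                                          restored_optimizer_state):
  np.testing.assert_allclose(original_value, restored_value)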

allenlavoie self-assigned this Apr 18, 2019
@princessofpillows
Author

princessofpillows commented Apr 20, 2019

After running into an issue where I did not know that my model was not being loaded, I think that status.assert_existing_objects_matched() and status.assert_consumed() should be run automatically whenever restoring from a checkpoint.

As an example, in my case the model was not loading the weights but nothing told me so, and the only check I was doing was that the weights weren't empty. It turned out the weights were just random initializations, and since the loading error is not fatal, I spent hours debugging why my model was not performing properly. This is especially strange because I was reloading the model in the same file, with no changes made to it. It also passed status.assert_existing_objects_matched() and failed on status.assert_consumed(), outputting:

Traceback (most recent call last):
  File "simulator.py", line 131, in <module>
    main()
  File "simulator.py", line 128, in main
    model.predict(s0, action)
  File "simulator.py", line 117, in predict
    self.status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "beta1_power"
  checkpoint_key: "optim/beta1_power/.ATTRIBUTES/VARIABLE_VALUE"
}

It looks like something to do with the optimizer, but it should tell me when this happens, because otherwise (as in my case, until this experience) I would have just kept assuming it was working.

Also, the weirdest part was that some other variables, such as an integer epoch counter, were restored properly while this wasn't, so the behaviour is also inconsistent across what gets restored.
I think not having an all-or-nothing approach to restoring variables is not very intuitive.

@allenlavoie
Member

When should assert_consumed run automatically? It will typically fail right after restore.
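
To illustrate why it typically fails right then (a self-contained sketch with stand-in names, not code from this thread): with tf.train.AdamOptimizer, beta1_power/beta2_power and the m/v slots are only created on the first apply_gradients() call, so assert_existing_objects_matched() can pass while assert_consumed() keeps failing until the optimizer takes a step.

import tensorflow as tf

tf.enable_eager_execution()

def build():
  # input_shape makes the model build immediately, so right after a restore
  # only the optimizer's variables are still missing.
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
  return model, tf.train.AdamOptimizer(0.001)

def step(model, optimizer):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(tf.ones([1, 3])))
  grads = tape.gradient(loss, model.trainable_variables)
  # The first call creates beta1_power/beta2_power and the m/v slots, which is
  # also the point where any deferred restore of them completes.
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

model, optimizer = build()
step(model, optimizer)  # create optimizer state so there is something to save
save_path = tf.train.Checkpoint(model=model, optim=optimizer).save("/tmp/tf_ckpts/")

model, optimizer = build()
status = tf.train.Checkpoint(model=model, optim=optimizer).restore(save_path)
status.assert_existing_objects_matched()  # passes: every existing object matched
# status.assert_consumed()  # would fail: optim/beta1_power etc. not created yet
step(model, optimizer)      # optimizer variables created, deferred restore done
status.assert_consumed()    # passes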

Can you share a reproduction for the unchanged file issue you ran into? That sounds like a bug.

@allenlavoie
Member

Or how does this sound? We can print a warning on program exit by default if a checkpoint was partially loaded. Status objects will have an "allow_partial" which silences the warning.

@princessofpillows
Author

I think that’s a good middle-of-the-road solution: it gives the user the relevant information without raising a fatal error, which might be undesirable for some.

As for the potential bug, I’ll try to put together a minimal reproduction later this week, as it’s in a fairly complicated system.

tensorflow-copybara pushed a commit that referenced this issue Apr 30, 2019
…ores

Will eventually (on __del__, so maybe at program shutdown) complain about values in the checkpoint which weren't used with restore-on-create. Adds an expect_partial() method to status objects to silence these warnings for the case where a partial restore was intended.

Following up on #27937

PiperOrigin-RevId: 245992963
@allenlavoie
Member

Thanks for the feedback. We have a warning for partial checkpoint restores now (in the latest nightly), and the eager guide now mentions deferred restoration and points to the checkpoint guide.
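
For reference, a rough sketch of how that looks with the expect_partial() method added in the commit above (names here are illustrative, and this needs a build that includes the change, e.g. a recent nightly, not 1.13.1):

import tensorflow as tf

tf.enable_eager_execution()

# Save a checkpoint containing both a model and an optimizer with state.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
optimizer = tf.train.AdamOptimizer(0.001)
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(model(tf.ones([1, 3])))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
save_path = tf.train.Checkpoint(model=model, optim=optimizer).save("/tmp/tf_ckpts/")

# Deliberately restore only the model, leaving the optimizer values unused.
fresh_model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
status = tf.train.Checkpoint(model=fresh_model).restore(save_path)
status.expect_partial()  # the partial restore is intentional; no shutdown warning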
