TensorFlow 2.0 - ValueError: tf.function-decorated #36574

Closed

arjun-majumdar opened this issue Feb 8, 2020 · 14 comments
Assignees
Labels
comp:keras Keras related issues TF 2.0 Issues relating to TensorFlow 2.0 type:bug Bug

Comments

@arjun-majumdar

Hello, I have code (for the MNIST dataset) in which I perform the following steps:

  1. Train the model
  2. Prune the model (using "tensorflow_model_optimization", for say p%)
  3. Create a mask of the pruned model, so that the sparsity is maintained in subsequent steps
  4. Reset the weights of the non-pruned connections to the random weights they had when the model was initialized

I repeat these steps iteratively 'n' times (a sketch of the masking in step 3 is shown after the notebook link below).

The code can be found in:
https://github.com/arjun-majumdar/tensorflow_codes/blob/master/Recreating_Error.ipynb
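For reference, the masking in step 3 can be sketched like this (a minimal illustration, not the notebook's code; pruned_model is a placeholder name):

import tensorflow as tf

# Build a binary mask per trainable weight tensor: 1.0 where the weight
# survived pruning, 0.0 where it was pruned to zero. Multiplying the
# gradients by these masks keeps pruned connections at zero during
# retraining.
masks = [tf.cast(tf.not_equal(w, 0.0), tf.float32)
         for w in pruned_model.trainable_weights]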

For retraining a pruned model, I use tf.GradientTape along with the mask. The first time the model is trained using the train_one_step() and test_step() functions, which are @tf.function-annotated, everything works fine. But when I try to use them again (in cell 76 of the Jupyter notebook), I get this error:


ValueError                                Traceback (most recent call last)
<ipython-input-...> in <module>
13 for x, y in train_dataset:
14 # train_step(x, y)
---> 15 train_one_step(model_gt_stripped, mask_model_stripped, optimizer, x, y)
16
17 for x_t, y_t in test_dataset:

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
455
456 tracing_count = self._get_tracing_count()
--> 457 result = self._call(*args, **kwds)
458 if tracing_count == self._get_tracing_count():
459 self._call_counter.called_without_tracing()

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
485 # In this case we have created variables on the first call, so we run the
486 # defunned version which is guaranteed to never create variables.
--> 487 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
488 elif self._stateful_fn is not None:
489 # Release the lock early so that multiple threads can perform the call

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
1820 def __call__(self, *args, **kwargs):
1821 """Calls a graph function specialized to the inputs."""
-> 1822 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
1823 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
1824

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2148 graph_function = self._function_cache.primary.get(cache_key, None)
2149 if graph_function is None:
-> 2150 graph_function = self._create_graph_function(args, kwargs)
2151 self._function_cache.primary[cache_key] = graph_function
2152 return graph_function, args, kwargs

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039 arg_names=arg_names,
2040 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041 capture_by_value=self._capture_by_value),
2042 self._function_attributes,
2043 # Tell the ConcreteFunction to clean up its graph once it goes out of

~/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
913 converted_func)
914
--> 915 func_outputs = python_func(*func_args, **func_kwargs)
916
917 # invariant: func_outputs contains only Tensors, CompositeTensors,

~/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
356 # wrapped allows AutoGraph to swap in a converted function. We give
357 # the function a weak reference to itself to avoid a reference cycle.
--> 358 return weak_wrapped_fn().wrapped(*args, **kwds)
359 weak_wrapped_fn = weakref.ref(wrapped_fn)
360

~/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
903 except Exception as e: # pylint:disable=broad-except
904 if hasattr(e, "ag_error_metadata"):
--> 905 raise e.ag_error_metadata.to_exception(e)
906 else:
907 raise

ValueError: in converted code:

<ipython-input-44-d0ca499a4063>:29 train_one_step  *
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:435 apply_gradients
    self._create_slots(var_list)
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/adam.py:146 _create_slots
    self.add_slot(var, 'm')
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:587 add_slot
    initial_value=initial_value)
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
    return cls._variable_v2_call(*args, **kwargs)
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
    shape=shape)
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter
    return captured_getter(captured_previous, **kwargs)
/home/majumdar/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
    "tf.function-decorated function tried to create "

ValueError: tf.function-decorated function tried to create variables on non-first call.

The only way to avoid this "ValueError" is to re-run the cells that define the train_one_step() and test_step() @tf.function-annotated functions!

Why is this happening?

Thanks!
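For reference, this failure mode can be reproduced standalone with a toy sketch like the following (hypothetical models and data, not the notebook's code):

import tensorflow as tf

opt = tf.keras.optimizers.Adam()

@tf.function
def train_one_step(model, x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Adam creates its slot variables ('m', 'v') lazily, on the first
    # apply_gradients() call it sees for each model variable.
    opt.apply_gradients(zip(grads, model.trainable_variables))

x = tf.random.normal((4, 8))
y = tf.random.normal((4, 1))
model_a = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model_b = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])

train_one_step(model_a, x, y)  # OK: slots created during the first trace
train_one_step(model_b, x, y)  # new model object -> retrace -> slots for
                               # model_b created on a non-first call -> ValueError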

@oanush oanush assigned oanush and unassigned amahendrakar Feb 10, 2020
@oanush oanush added comp:autograph Autograph related issues type:support Support issues TF 2.0 Issues relating to TensorFlow 2.0 labels Feb 10, 2020
@oanush commented Feb 10, 2020

@arjun-majumdar,
Can you please take a look at the Stack Overflow link and let me know if it helps. Thanks!

@oanush oanush added the stat:awaiting response Status - Awaiting response from author label Feb 10, 2020
@arjun-majumdar (Author)

Going through the link and the related tutorial.

@arjun-majumdar (Author) commented Feb 13, 2020

Hello, I am seeing the following behavior, which doesn't make sense to me. My code is as follows:

import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

# loss_fn, optimizer, the four metric objects (train_loss, train_accuracy,
# test_loss, test_accuracy), num_epochs, train_dataset, test_dataset,
# pruned_nn(), pruning_params_unpruned and mask_model_stripped are all
# defined earlier in the notebook.

@tf.function
def train_step(model, mask_model, optimizer, x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = loss_fn(y, y_pred)
    
    grads = tape.gradient(loss, model.trainable_variables)
    
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(y, y_pred)
    
    return None


@tf.function
def test_step(model, optimizer, data, labels):
    """
    Function to test model performance
    on testing dataset
    """
    
    predictions = model(data)
    t_loss = loss_fn(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

    return None


# User input parameters for Early Stopping in manual implementation-
minimum_delta = 0.001
patience = 3

best_val_loss = 1
loc_patience = 0


# Initialize a neural network model-
model_gt = pruned_nn(pruning_params_unpruned)

# Strip model of pruning layers-
model_gt_stripped = sparsity.strip_pruning(model_gt)

for i in range(1, 6):
    
    print("\n\n\nOuter loop: {0}\n\n".format(i))
    
    # Initialize parameters for Early Stopping manual implementation-
    best_val_loss = 1
    loc_patience = 0
    
    for epoch in range(num_epochs):
    
        if loc_patience >= patience:
            print("\n'EarlyStopping' called!\n")
            break
        
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()
        
        '''
        # Initialize 'grad_mask_mul' list-
        grad_mask_mul = []
    
        # Initialize all values to one-
        for wts in mask_model_stripped.trainable_weights:
            grad_mask_mul.append(wts.assign(tf.ones_like(input = wts,dtype = tf.float32)))
    
        # Convert from Python list to tf.Tensor-
        grad_mask_mul = tf.convert_to_tensor(grad_mask_mul, dtype=tf.float32)
    
        print("type(grad_mask_mul): {0}".format(type(grad_mask_mul)))
        '''
    
        for x, y in train_dataset:
            train_step(model_gt_stripped, mask_model_stripped, optimizer, x, y)
            # train_one_step(model_gt_stripped, mask_model, optimizer, x, y, grad_mask_mul)


        for x_t, y_t in test_dataset:
            # test_step(x_t, y_t)
            test_step(model_gt_stripped, optimizer, x_t, y_t)

        template = 'Epoch {0}, Loss: {1:.4f}, Accuracy: {2:.4f}, Test Loss: {3:.4f}, Test Accuracy: {4:.4f}'
        
        print(template.format(epoch + 1, 
                              train_loss.result(), train_accuracy.result()*100,
                              test_loss.result(), test_accuracy.result()*100))
    
        # Count number of non-zero parameters in each layer and in total-
        # print("layer-wise manner model, number of nonzero parameters in each layer are: \n")

        model_sum_params = 0
    
        for layer in model_gt_stripped.trainable_weights:
            # print(tf.math.count_nonzero(layer, axis = None).numpy())
            model_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()
    
        print("Total number of trainable parameters = {0}\n".format(model_sum_params))

    
        # Code for manual Early Stopping: an improvement means the test loss
        # dropped by at least 'minimum_delta' below the best loss seen so far-
        if best_val_loss - test_loss.result() >= minimum_delta:
            # update 'best_val_loss' variable to lowest loss encountered so far-
            best_val_loss = test_loss.result()
        
            # reset 'loc_patience' variable-
            loc_patience = 0
        
        else:  # there is no improvement in monitored metric 'val_loss'
            loc_patience += 1  # number of epochs without any improvement

If I re-execute the for i in range(1, 6): block of code again, why do I get "ValueError: tf.function-decorated function tried to create variables on non-first call."?

What I don't understand is this: the "train_step()" and "test_step()" "tf.function"-annotated functions have already been traced, and the tf.Graph objects have been created for them. Also, the parameters being passed to these functions are not changing. Is the "ValueError" caused by the line:

grads = tape.gradient(loss, model.trainable_variables) within the "train_step()" function

and

predictions = model(data) within the "test_step()" function?

These are the only two lines within the two functions that produce new values, but then again, the gradients with respect to the parameters and the model's predictions will always be computed within the "tf.function"-annotated functions; you cannot pass such values as parameters to the function(s).

Thanks
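As the traceback above indicates (optimizer_v2.py: _create_slots, add_slot(var, 'm')), the variables are actually created by the Adam optimizer, which builds its 'm' and 'v' slot variables lazily on the first apply_gradients() call it sees for a given model variable; neither tape.gradient() nor model(data) creates tf.Variables. One hedged sketch of a way to force that slot creation to happen eagerly, outside any traced function (model and optimizer being the objects from the code above):

# Apply an all-zero gradient once, eagerly, so Adam creates its slot
# variables ('m' and 'v') for every trainable variable *before* any
# tf.function-traced call. With zero gradients the parameter update is
# zero (m and v stay at zero), though the optimizer's iteration counter
# does advance by one.
zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))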

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Feb 21, 2020
@Saduf2019 Saduf2019 assigned Saduf2019 and unassigned oanush Feb 26, 2020
@Saduf2019 (Contributor)
@arjun-majumdar
Please provide us with simple, indented, standalone code with all dependencies, so that we can replicate the issue in our environment and analyze it.
Please find the gist of the code you shared and the error faced.


@Saduf2019 Saduf2019 assigned gowthamkpr and unassigned Saduf2019 Mar 9, 2020
@gowthamkpr gowthamkpr removed their assignment Mar 30, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 30, 2020
@nikitamaia (Member)

@arjun-majumdar have you seen the related issue #27120?

This comment might help you. Try creating a wrapper function for your train_one_step() function, and then call it separately when you train your different models.
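Roughly, the suggested workaround looks like this (a minimal sketch; loss_fn is assumed to be defined as in the code above):

def make_train_step():
    # Each call returns a brand-new tf.function with its own trace cache.
    # Variable creation (e.g. the optimizer's slots) then happens on *that*
    # Function's first call, which is allowed.
    @tf.function
    def train_one_step(model, optimizer, x, y):
        with tf.GradientTape() as tape:
            y_pred = model(x)
            loss = loss_fn(y, y_pred)  # loss_fn as defined in the notebook
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    return train_one_step

# One wrapper instance per model/optimizer pair:
train_step_1 = make_train_step()
train_step_2 = make_train_step()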

@arjun-majumdar (Author)

@nikitamaia Let me have a look and get back to you. Thanks!

@nikitamaia (Member)

Hi @arjun-majumdar were you able to get your code working by creating a wrapper function?

@arjun-majumdar (Author)

Hello @nikitamaia, the tf.function-annotated function does work after wrapping, however the performance benefit gained from the @tf.function annotation is lost.

Also, why should the generated graph be retraced if the neural network architecture isn't changing and the data types of the tensors aren't changing either?
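For what it's worth, retracing can be observed directly with a Python-level side effect; a plain print() runs only while a new graph is being traced (a standard diagnostic sketch, not code from this thread):

@tf.function
def f(x):
    print("Tracing!")        # Python side effect: runs only during tracing
    tf.print("Executing")    # graph op: runs on every call
    return x + 1

f(tf.constant(1))    # "Tracing!" then "Executing" (first trace)
f(tf.constant(2))    # only "Executing" (same signature, cached graph)
f(tf.constant(1.0))  # "Tracing!" again (new dtype forces a retrace)

Passing a different Python object (such as a new Keras model) as an argument also produces a new cache key, and hence a retrace, even when the architecture and dtypes are identical.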

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 25, 2020
@nikitamaia (Member)

Sorry for the late response here. I wanted to provide a quick update: this does seem to be a bug. I know the workaround of wrapping the function is not ideal, but I will update this thread when there's progress.

@arjun-majumdar (Author)

@nikitamaia Thanks for the reply. I will be happy to receive an update if there is a fix/solution found for this bug.

@nikitamaia nikitamaia added comp:keras Keras related issues and removed comp:autograph Autograph related issues labels Aug 24, 2020
@nikitamaia (Member)

The Better Performance With tf.function guide has now been updated to provide more detail about this error and about using tf.Variables with multiple Keras models or optimizers. Closing this issue now, since there is a workaround provided in the docs.
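The relevant pattern from the guide is, roughly, to create a tf.Variable inside a tf.function only if it does not exist yet (a sketch paraphrasing the guide, not its exact code):

import tensorflow as tf

class Counter(tf.Module):
    def __init__(self):
        self.count = None  # variable not created yet

    @tf.function
    def __call__(self):
        # Create the variable during the first trace only; on later calls
        # (and retraces for this same object) it already exists.
        if self.count is None:
            self.count = tf.Variable(0)
        return self.count.assign_add(1)

c = Counter()
print(c())  # 1
print(c())  # 2

For multiple models or optimizers, the guide's other suggestion is the one already discussed in this thread: create a fresh Function (or wrapper) per model/optimizer pair.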


@sjtusmartboy
@nikitamaia I'm afraid that isn't a workaround; it's just the rule, and it doesn't meet the user's need.
