Recent change in tf.vectorized_map breaks MCMC when batch_shape = 1 #1071

Open
junpenglao opened this issue Aug 30, 2020 · 3 comments

@junpenglao
Contributor

Minimal reproducible example:

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(0., 1.)

def vectorized_logpfn(*state): 
    return tf.vectorized_map(lambda mini_state: dist.log_prob(*mini_state), state)

init = dist.sample(1)  # shape (1,): a unit batch, which is what triggers the error

@tf.function
def run_fn(init, burn_in):
    return tfp.mcmc.sample_chain(
        10, init, 
        num_burnin_steps=burn_in,
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
            vectorized_logpfn, .1, num_leapfrog_steps=5))

run_fn(init, 10)

raises:

ValueError                                Traceback (most recent call last)
<ipython-input-17-28b7450607cd> in <module>
      7             vectorized_logpfn, .1, num_leapfrog_steps=5))
      8 
----> 9 run_fn(init, 10)

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    784     tracing_count = self._get_tracing_count()
    785     with trace.Trace(self._name) as tm:
--> 786       result = self._call(*args, **kwds)
    787       compiler = "xla" if self._experimental_compile else "nonXla"
    788       new_tracing_count = self._get_tracing_count()

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    827       # This is the first call of __call__, so we have to initialize.
    828       initializers = []
--> 829       self._initialize(args, kwds, add_initializers_to=initializers)
    830     finally:
    831       # At this point we know that the initialization is complete (or less

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    715     self._concrete_stateful_fn = (
    716         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 717             *args, **kwds))
    718 
    719     def invalid_creator_scope(*unused_args, **unused_kwds):

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2953       args, kwargs = None, None
   2954     with self._lock:
-> 2955       graph_function, _ = self._maybe_define_function(args, kwargs)
   2956     return graph_function
   2957 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3353 
   3354           self._function_cache.missed.add(call_context_key)
-> 3355           graph_function = self._create_graph_function(args, kwargs)
   3356           self._function_cache.primary[cache_key] = graph_function
   3357 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3198             arg_names=arg_names,
   3199             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3200             capture_by_value=self._capture_by_value),
   3201         self._function_attributes,
   3202         function_spec=self.function_spec,

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    985         _, original_func = tf_decorator.unwrap(python_func)
    986 
--> 987       func_outputs = python_func(*func_args, **func_kwargs)
    988 
    989       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    623             xla_context.Exit()
    624         else:
--> 625           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    626         return out
    627 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    972           except Exception as e:  # pylint:disable=broad-except
    973             if hasattr(e, "ag_error_metadata"):
--> 974               raise e.ag_error_metadata.to_exception(e)
    975             else:
    976               raise

ValueError: in user code:

    <ipython-input-17-28b7450607cd>:4 run_fn  *
        10, init,
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:374 sample_chain  **
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:464 trace_scan
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:174 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:450 _body
        state = loop_fn(state, elem)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:358 _trace_scan_fn
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:353 smart_for_loop
        parallel_iterations=parallel_iterations
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:174 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:351 <lambda>
        body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:351 _seeded_one_step
        kernel.one_step(*state_and_results, **one_step_kwargs))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/hmc.py:574 one_step
        current_state, previous_kernel_results, seed=seed)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py:218 one_step
        **inner_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/hmc.py:777 one_step
        current_target_log_prob_grad_parts)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py:291 __call__
        target_grad_parts,
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:180 wrapped_body
        expand_composites=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/nest.py:411 assert_same_structure
        % (str(e), str1, str2))

    ValueError: The two structures don't have the same nested structure.
    
    First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x7f37d3f45e90>]]
    
    Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_2:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_4:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_5:0' shape=(1,) dtype=float32>]]
    
    More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_5:0", shape=(1,), dtype=float32)" is not
    Entire first structure:
    [., [.], [.], ., [.]]
    Entire second structure:
    [., [.], [.], ., [.]]
@junpenglao
Contributor Author

The bug seems to have been introduced in tensorflow/tensorflow@c2e5944; it was discovered in pymc-devs/pymc4#317 (comment).

@davmre I am assigning this to you since you have a bit more context about the change.

@davmre
Contributor

davmre commented Sep 2, 2020

Thanks for reporting; this is a pretty weird issue. After poking for a bit I think the root of the problem is that this snippet computing the gradient of tf.gather:

import tensorflow as tf

@tf.function(autograph=False)
def gather_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    v = tf.gather(x, 0)
  g = tape.gradient(v, x)
  return g
gather_grad(x=tf.convert_to_tensor([1.]))

returns a
<google3.third_party.tensorflow.python.framework.indexed_slices.IndexedSlices at 0x7fb5f27bc4a8> instance instead of a plain Tensor. The IndexedSlices instance is convertible to a Tensor, but its underlying representation is made of several Tensors (the sliced values, their indices, and the dense shape), and that screws up the HamiltonianMonteCarlo while_loop, which expects its loop state to keep the same Tensor structure it was initialized with.
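To make the structure mismatch concrete, here is a minimal sketch (assuming TF 2.x in eager mode). Rather than relying on the gather gradient, it builds an IndexedSlices by hand with the same shapes as the one in the traceback, then compares it against a plain Tensor using tf.nest.assert_same_structure(..., expand_composites=True), which is the check while_v2 runs on its loop variables. The same_structure helper is hypothetical, for illustration only.

```python
import tensorflow as tf

# What the loop state was initialized with: a plain Tensor.
dense_state = tf.zeros([1])

# What the gather gradient looks like: an IndexedSlices made of three tensors.
sparse_grad = tf.IndexedSlices(
    values=tf.constant([1.]),
    indices=tf.constant([0]),
    dense_shape=tf.constant([1]))


def same_structure(a, b):
  """Hypothetical helper: True if a and b match once composite tensors are
  expanded into their component tensors (the check while_loop performs)."""
  try:
    tf.nest.assert_same_structure(a, b, expand_composites=True)
    return True
  except (ValueError, TypeError):
    return False


print(same_structure(dense_state, sparse_grad))  # False: Tensor vs. IndexedSlices
densified = tf.convert_to_tensor(sparse_grad)    # scatter values into a dense Tensor
print(same_structure(dense_state, densified))    # True
print(densified.numpy())                         # [1.]
```

Once densified, the gradient has the same structure as the Tensor the loop started with, so the structure check no longer fails.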

The contribution of tensorflow/tensorflow@c2e5944 is somewhat tangential: it calls tf.gather(x, 0) for unit-batch Tensors directly, where previously the autovectorization machinery would see tf.gather(x, i) (where i is an abstract batch index) and do something more complicated that I think might end up eliding the gather altogether. The change is fine IMHO, but it seems to have triggered this complicated interaction.

I think we'll need to consult the TF Core team on the most natural fix: it might make sense to change the gradient definition for tf.gather, or to have while_loop try to convert any CompositeTensors in its loop state to Tensors before giving up. I'll file a couple of bugs.

@davmre
Contributor

davmre commented Sep 2, 2020

Actually it might make more sense to just work around this at the TFP level by calling convert_to_tensor on all gradients inside the MCMC loop. I'll follow up tomorrow.
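A minimal sketch of that workaround, assuming TF 2.x. The helpers densify_grads and value_and_dense_grad are hypothetical names, not TFP's actual internals. They convert every gradient to a dense Tensor before it is written back into loop state. toy_loop is a toy stand-in for the HMC/leapfrog while_loop. Its gradient slot starts as a plain Tensor, like Placeholder_5 in the traceback above, and it takes the gradient of the same unit-batch tf.gather.

```python
import tensorflow as tf


def densify_grads(grads):
  """Hypothetical helper: turn every gradient in a (possibly nested) structure
  into a dense tf.Tensor, so IndexedSlices never leak into while_loop state."""
  return tf.nest.map_structure(
      lambda g: None if g is None else tf.convert_to_tensor(g), grads)


def value_and_dense_grad(fn, x):
  """Returns fn(x) and its gradient w.r.t. x, with the gradient densified."""
  with tf.GradientTape() as tape:
    tape.watch(x)
    value = fn(x)
  return value, densify_grads(tape.gradient(value, x))


@tf.function(autograph=False)
def toy_loop(x):
  """Toy stand-in for an MCMC inner loop that carries a gradient in its state."""
  target = lambda s: tf.gather(s, 0)  # same unit-batch gather as in the issue

  def body(step, state, grad):
    _, new_grad = value_and_dense_grad(target, state)
    return step + 1, state - 0.1 * new_grad, new_grad

  # The gradient slot starts out as a plain Tensor.
  init = (tf.constant(0), x, tf.zeros_like(x))
  return tf.while_loop(lambda step, state, grad: step < 3, body, init)


step, x_final, g_final = toy_loop(tf.constant([1.]))
print(step.numpy(), x_final.numpy(), g_final.numpy())
```

Without densify_grads, the body would likely hand back an IndexedSlices for the gradient slot, and tf.while_loop would then fail with the same "don't have the same nested structure" error as in the report.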
