tf.case doesn't preserve shape information #3334

Closed
alexatknit opened this issue Jul 15, 2016 · 11 comments
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower


@alexatknit

alexatknit commented Jul 15, 2016

tf.case is a Python implementation of a case statement built on tf.cond, but unlike cond it doesn't preserve static shape information in its output. This is because of this little snippet:

...
    # preds = [p1, p2, p3]
    # fns = [f1, f2, f3]
    # not_preds = [~p1, ~p2, ~p3]
    # and_not_preds = [True, ~p1, ~p1 & ~p2, ~p1 & ~p2 & ~p3]
    # case_preds = [p1,
    #               p2 & ~p1,
    #               p3 & ~p2 & ~p1,
    #              ~p3 & ~p2 & ~p1]

    case_preds = []
    for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds[:-1])):
      with ops.name_scope("case_%d" % i):
        case_preds.append(math_ops.logical_and(p, and_not_p_prev))
    with ops.name_scope("case_none_are_true"):
      case_preds.append(and_not_preds[-1])

    # Create an empty tensor, or list, with the right type and shape
    with ops.name_scope("case_create_empty"):
      dummy_value = default()
      def _correct_empty(v):
        if isinstance(v, ops.Operation):
          return no_op()
        elif v.dtype == dtypes.string:
          return array_ops.constant("")
        else:
          return array_ops.constant(v.dtype.as_numpy_dtype())

      if isinstance(dummy_value, collections.Sequence):
        dummy_type = type(dummy_value)
        empty = lambda: dummy_type(_correct_empty(v) for v in dummy_value)
      else:
        empty = lambda: _correct_empty(dummy_value)

    # case_sequence = [
    #   cond(~p3 & ~p2 & ~p1, default, empty),
    #   cond(p3 & ~p2 & ~p1, f3, lambda: case_sequence[0]),
    #   cond(p2 & ~p1, f2, lambda: case_sequence[1]),
    #   cond(p1, f1, lambda: case_sequence[2])
    # ]
    #
    # And the return value will be case_sequence[-1]
    def _build_case():
      all_fns = [fn for fn in fns]
      all_fns.append(default)
      prev_case = None
      for i, (cp, fn) in enumerate(list(zip(case_preds, all_fns))[::-1]):
        prev_case = cond(
            cp, fn,
            empty if i == 0 else lambda: prev_case,
            name="If_%d" % i)
      return prev_case
...
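The case_preds bookkeeping in the excerpt can be checked with plain Python booleans. The sketch below is a toy model (no TensorFlow; the name exclusive_case_preds is hypothetical) that mirrors the and_not_preds / case_preds construction and confirms that exactly one case predicate is true for every input, with the last entry standing for the default branch.

```python
from itertools import product


def exclusive_case_preds(preds):
    """Toy model of the case_preds construction in the tf.case excerpt.

    and_not_preds[i] is True only if none of preds[0..i-1] holds, so
    case_preds[i] = preds[i] and and_not_preds[i]; a final entry is True
    only when no predicate holds (the "case_none_are_true" default case).
    """
    and_not_preds = [True]
    for p in preds:
        and_not_preds.append(and_not_preds[-1] and not p)
    case_preds = [p and prev for p, prev in zip(preds, and_not_preds[:-1])]
    case_preds.append(and_not_preds[-1])
    return case_preds


# Every assignment of three predicates selects exactly one case.
for preds in product([False, True], repeat=3):
    assert sum(exclusive_case_preds(list(preds))) == 1

print(exclusive_case_preds([False, True, True]))  # → [False, True, False, False]
```

Because earlier predicates win, p3 is ignored in the last call even though it is true.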

The op works by evaluating a series of predicates (including one for the default value), but the chain of conds starts off with an empty dummy object. That dummy seems designed to pass on the correct type and shape information, but it fails to do so in my use case: for non-string values, _correct_empty returns a scalar constant of the right dtype, so its shape does not match non-scalar branch outputs. I recommend changing this code to read:

...
    # preds = [p1, p2, p3]
    # fns = [f1, f2, f3]
    # not_preds = [~p1, ~p2, ~p3]
    # and_not_preds = [True, ~p1, ~p1 & ~p2, ~p1 & ~p2 & ~p3]
    # case_preds = [p1,
    #               p2 & ~p1,
    #               p3 & ~p2 & ~p1]

    case_preds = []
    for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds[:-1])):
      with ops.name_scope("case_%d" % i):
        case_preds.append(math_ops.logical_and(p, and_not_p_prev))

    # case_sequence = [
    #   cond(p3 & ~p2 & ~p1, f3, default),
    #   cond(p2 & ~p1, f2, lambda: case_sequence[0]),
    #   cond(p1, f1, lambda: case_sequence[1])
    # ]
    #
    # And the return value will be case_sequence[-1]
    def _build_case():
      all_fns = [fn for fn in fns]
      prev_case = None
      for i, (cp, fn) in enumerate(list(zip(case_preds, all_fns))[::-1]):
        prev_case = cond(
            cp, fn,
            default if prev_case is None else lambda: prev_case,
            name="If_%d" % i)
      return prev_case
...

This removes the need both for a dummy empty op and for a separate predicate for the default op, shortening the whole op by about 18 lines of code.
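Why the scalar dummy destroys shape information can be illustrated with a small model of how a cond output's static shape is inferred from its two branches. This is a sketch under an assumption: merge_static_shapes is a hypothetical helper, and the rule it encodes (differing ranks give an unknown rank, differing dimensions give an unknown dimension) is assumed to match TensorFlow's merge behaviour rather than quoted from it. The (32, 10) branch shape is made up for illustration.

```python
def merge_static_shapes(a, b):
    """Toy model of the static shape of cond(pred, a_fn, b_fn).

    Shapes are tuples of ints (None marks an unknown dimension), or None for
    an unknown rank. Assumed rule: if the ranks differ the result has unknown
    rank; otherwise each dimension is kept only where both branches agree.
    """
    if a is None or b is None or len(a) != len(b):
        return None
    return tuple(x if x == y else None for x, y in zip(a, b))


branch_shape = (32, 10)  # hypothetical shape returned by f1, f2, f3 and default
dummy_shape = ()         # constant(v.dtype.as_numpy_dtype()) is a scalar

# Original chain: the innermost cond merges default() with the scalar dummy...
innermost = merge_static_shapes(branch_shape, dummy_shape)
print(innermost)  # → None
# ...and every outer cond merges a real branch with that unknown shape.
print(merge_static_shapes(branch_shape, innermost))  # → None

# Proposed chain: default() takes the dummy's place, so the shapes agree.
print(merge_static_shapes(branch_shape, branch_shape))  # → (32, 10)
```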

@concretevitamin
Contributor

Assigning to @yuanbyu to take a look (although he's out of office currently).

@concretevitamin concretevitamin added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 18, 2016
@yuanbyu yuanbyu assigned ebrevdo and unassigned yuanbyu Jul 19, 2016
@michaelisard michaelisard removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 25, 2016
@ebrevdo
Contributor

ebrevdo commented Aug 3, 2016

This is what case originally looked like (look at the git history for the file). Unfortunately, this has the bad side effect of always executing at least two nodes (default and the one that evaluated to true). Even if the default node's value was not actually returned, the node was still executed and could therefore have side effects. Feel free to change the code as you suggested and you'll see a unit test in control_flow_ops start failing.
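The side effect described above can be reproduced with a pure-Python toy model (no TensorFlow). The assumption it encodes: in the proposed code each inner cond is created outside the enclosing cond's branch functions, so in graph mode it runs whenever the chain is evaluated; the model captures that by evaluating the chain innermost-first. toy_cond, make_fn and proposed_case are hypothetical names.

```python
calls = []


def make_fn(name):
    """Return a branch function that records when it runs (a stand-in side effect)."""
    def fn():
        calls.append(name)
        return name
    return fn


def toy_cond(pred, true_fn, false_fn):
    """Only the selected branch function runs, as with tf.cond's branches."""
    return true_fn() if pred else false_fn()


def proposed_case(case_preds, fns, default):
    """Model of the proposed _build_case: conds are built (and run) innermost-first,
    and each outer cond's false branch just returns the already-computed result."""
    prev_case = None
    for cp, fn in reversed(list(zip(case_preds, fns))):
        else_fn = default if prev_case is None else (lambda v=prev_case: v)
        prev_case = toy_cond(cp, fn, else_fn)
    return prev_case


# p1 is true, so the (exclusive) case predicates are [True, False, False].
result = proposed_case([True, False, False],
                       [make_fn("f1"), make_fn("f2"), make_fn("f3")],
                       make_fn("default"))
print(result, calls)  # → f1 ['default', 'f1']
```

The returned value is f1's, yet default also ran, because the innermost cond (whose predicate p3 & ~p2 & ~p1 is false) selected it before the outer conds were evaluated.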

The real solution is to factor out cond() into a sequence of internal _if_then statements, which can be used for case as well.
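One way to get the lazy behaviour is to build each remaining cond inside the previous cond's false branch, so later branches (and the default) are only constructed, and therefore only run, on the path where every earlier predicate was false. The pure-Python sketch below is a toy model of that idea, not the _if_then refactoring itself; nested_case, toy_cond and make_fn are hypothetical names.

```python
calls = []


def make_fn(name):
    """Return a branch function that records when it runs (a stand-in side effect)."""
    def fn():
        calls.append(name)
        return name
    return fn


def toy_cond(pred, true_fn, false_fn):
    """Only the selected branch function runs, as with tf.cond's branches."""
    return true_fn() if pred else false_fn()


def nested_case(pairs, default, cond=toy_cond):
    """First-true-wins case: the rest of the chain is built lazily inside the
    false branch, so nothing after the selected branch is ever constructed."""
    if not pairs:
        return default()
    (pred, fn), rest = pairs[0], pairs[1:]
    return cond(pred, fn, lambda: nested_case(rest, default, cond))


fns = [make_fn("f1"), make_fn("f2"), make_fn("f3")]
for preds, expected in [([True, False, False], "f1"),
                        ([False, True, True], "f2"),
                        ([False, False, False], "default")]:
    calls.clear()
    result = nested_case(list(zip(preds, fns)), make_fn("default"))
    assert result == expected and calls == [expected]  # only one branch ran
```

Because first-true-wins is enforced by the nesting itself, this version needs neither the exclusive case_preds nor a dummy value. Passing tf.cond positionally as the cond argument should give the same lazy structure in TensorFlow; it is essentially the nested tf.cond workaround mentioned later in the thread.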

@ebrevdo
Contributor

ebrevdo commented Aug 10, 2016

@yuanbyu another reason to split up cond/case into if_then type statements?

@ebrevdo
Contributor

ebrevdo commented Sep 22, 2016

This is on the backburner for a bit; sorry. No updates on when it'll be done.

@carlthome
Contributor

ETA on fix?

I'm currently nesting tf.cond calls, but it looks a little silly.

@ebrevdo
Contributor

ebrevdo commented Jan 23, 2017

The fix for this one is more involved than originally anticipated. Will update the issue with an ETA once we have an initial fix.

@aselle aselle added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 27, 2017
@gunan
Contributor

gunan commented Jun 16, 2017

Ping! Is there work going on to fix this issue?
Should we reassign this?

@tensorflowbutler
Member

It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@tensorflowbutler
Member

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@tensorflowbutler
Member

Nagging Assignee: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

@itsmeolivia
Contributor

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

10 participants