Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: tensorflow.placeholder shape does not serialize with protobuf #9103

Closed
figmentc opened this issue Apr 10, 2017 · 6 comments
Closed

BUG: tensorflow.placeholder shape does not serialize with protobuf #9103

figmentc opened this issue Apr 10, 2017 · 6 comments
Assignees
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug

Comments

@figmentc
Copy link

Profobuf serialization(json)
{
"attr": {
"dtype": {
"type": "DT_FLOAT"
},
"shape": {
"shape": {}
}
},
"name": "x",
"op": "Placeholder"
},

Tensorflow code
x = tf.placeholder(tf.float32, shape=None, name="x")

@figmentc figmentc changed the title tensorflow.placeholder shape does not serialize with protobuf BUG: tensorflow.placeholder shape does not serialize with protobuf Apr 10, 2017
@asimshankar asimshankar added the stat:awaiting response Status - Awaiting response from author label Apr 10, 2017
@asimshankar
Copy link
Contributor

Could you elaborate on the bug here? There is a known issue with the Placeholder op where it cannot distinguish between an unknown and a scalar shape, but it does serialize all other shapes correctly.

There is some work underway to figure out if that bug can be fixed without requiring the PlaceholderV2 operation, but all other shapes should be fine regardless.

Could you elaborate on your concern here?

@figmentc
Copy link
Author

figmentc commented Apr 10, 2017

Sorry, I copied the wrong line from Python as I was testing.
When a placeholder of shape [None, 784] was serialized, the corresponding element in the profobuf json serialization does not contain a shape attribute.

This is the python code:

    x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
    y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y_")
    with tf.name_scope("first_layer"):

        W = tf.Variable(tf.zeros([784,10]), name="W")
        b = tf.Variable(tf.zeros([10]), name="b")
    # Output
        y = tf.matmul(x,W) + b

    with tf.name_scope("softmax_layer"):
    # Loss Function
        softmax = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
    with tf.name_scope("error_check"):
        cross_entropy = tf.reduce_mean(softmax)

    with tf.name_scope("accuracy_check"):
    #Accuracy Calc
        correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess.run(tf.global_variables_initializer())

    outfile_txt = json_format.MessageToJson(sess.graph_def)
    outfile = open("outfile.json", 'w')
    outfile.write(outfile_txt)

Select elements from the output json file:

     {
      "attr": {
        "dtype": {
          "type": "DT_FLOAT"
        },
        "shape": {
          "shape": {}
        }
      },
      "name": "x",
      "op": "Placeholder"
    },
{
      "attr": {
        "shape": {
          "shape": {
            "dim": [
              {
                "size": "784"
              },
              {
                "size": "10"
              }
            ]
          }
        },
        "shared_name": {
          "s": ""
        },
        "container": {
          "s": ""
        },
        "dtype": {
          "type": "DT_FLOAT"
        }
      },
      "name": "first_layer/W",
      "op": "VariableV2"
    }

Let me know if this is because of what you said earlier. Im using tensorflow-gpu installed from pip3 on windows.

@aselle aselle removed the stat:awaiting response Status - Awaiting response from author label Apr 10, 2017
@jart
Copy link
Contributor

jart commented Apr 10, 2017

@asimshankar do you know if there's an issue tracking that work?

@asimshankar asimshankar added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug labels Apr 10, 2017
@asimshankar
Copy link
Contributor

@karpkarp : Thanks for the sample code. It seems that if any of the dimensions are unknown is when we end up with an empty shape in the GraphDef, which is broader than the problem PlaceholderV2 is going to address.

I'll dig in a bit more.

CC @vrv

@asimshankar asimshankar self-assigned this Apr 10, 2017
@vrv
Copy link

vrv commented Apr 10, 2017

Actually I'm trying to change Placeholder itself so no new V2 is needed, but this is precisely correct. We currently lose shape information when you serialize and deserialize partially known placeholder shapes. This is fixed in V2 which I am trying to backport to v1.

@figmentc
Copy link
Author

figmentc commented Apr 10, 2017

It seems in array_ops.py, it sets a requirement for the shape to be fully defined with shape.is_fully_defined() in the placeholder function. Any particular reason for this? Does this mean that a placeholder of shape of [None, SomeNum] will not be enforced?

In any case, I removed the condition where the Placeholder shape has to be fully defined and the serialization issues are fixed. This does break placeholders with no defined shape so I added two additional function in python/framework/tensor_shape.py

python/framework/tensor_shape

 def is_partially_defined(self):
    return self._dims is not None

  def assert_is_partially_defined(self):
    if not self.is_partially_defined(self):
      raise ValueError("Shape %s is not partially defined" % self)

python\ops\array_ops.py

def placeholder(dtype, shape=None, name=None):
  shape = tensor_shape.as_shape(shape)
  if shape.is_partially_defined():
    dim_list = shape.as_list()
  else:
    dim_list = []
  ret = gen_array_ops._placeholder(
      dtype=dtype,
      shape=dim_list,
      name=name)
  ret.set_shape(shape)
  return ret
  "versions": {
    "producer": 21
  },
  "node": [
    {
      "op": "Placeholder",
      "name": "x",
      "attr": {
        "shape": {
          "shape": {
            "dim": [
              {
                "size": "-1"
              },
              {
                "size": "784"
              }
            ]
          }
        },
        "dtype": {
          "type": "DT_FLOAT"
        }
      }
    },
    {
      "op": "Placeholder",
      "name": "y_",
      "attr": {
        "shape": {
          "shape": {
            "dim": [
              {
                "size": "-1"
              },
              {
                "size": "10"
              }
            ]
          }
        },
        "dtype": {
          "type": "DT_FLOAT"
        }
      }
    },
    {
      "op": "NoOp",
      "name": "init"
    }
  ]
}

figmentc pushed a commit to figmentc/tensorflow that referenced this issue Apr 10, 2017
@drpngx drpngx closed this as completed in 24a95ae Apr 11, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug
Projects
None yet
Development

No branches or pull requests

5 participants