Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MutableHashTable and DenseHashTable missing name parameter #29306

Merged
merged 1 commit into from Jun 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
119 changes: 119 additions & 0 deletions tensorflow/python/kernel_tests/lookup_ops_test.py
Expand Up @@ -1314,6 +1314,72 @@ def testSaveRestore(self):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())

@test_util.run_v1_only("Saver V1 only")
def testSaveRestoreOnlyTable(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

with self.session(graph=ops.Graph()) as sess:
default_value = -1
empty_key = 0
deleted_key = -1
keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
table = lookup_ops.DenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=32)

save = saver.Saver([table])

self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(4, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))

keys2 = constant_op.constant([12, 15], dtypes.int64)
table.remove(keys2).run()
self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))

val = save.save(sess, save_path)
self.assertIsInstance(val, six.string_types)
self.assertEqual(save_path, val)

with self.session(graph=ops.Graph()) as sess:
table = lookup_ops.DenseHashTable(
dtypes.int64,
dtypes.int64,
default_value=default_value,
empty_key=empty_key,
deleted_key=deleted_key,
name="t1",
checkpoint=True,
initial_num_buckets=64)
table.insert(
constant_op.constant([11, 14], dtypes.int64),
constant_op.constant([12, 24], dtypes.int64)).run()
self.assertAllEqual(2, table.size().eval())
self.assertAllEqual(64, len(table.export()[0].eval()))

save = saver.Saver([table])

# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)

self.assertAllEqual(3, table.size().eval())
self.assertAllEqual(32, len(table.export()[0].eval()))

input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, -1, 2, 3], output.eval())


@test_util.run_in_graph_and_eager_modes
def testObjectSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
Expand Down Expand Up @@ -2636,6 +2702,59 @@ def testSaveRestore(self):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

@test_util.run_v1_only("SaverV1")
def testSaveRestoreOnlyTable(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

with self.session(graph=ops.Graph()) as sess:
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(20.0, name="v1")

default_val = -1
keys = constant_op.constant(["b", "c", "d"], dtypes.string)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableHashTable(
dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)

save = saver.Saver([table])
self.evaluate(variables.global_variables_initializer())

# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))

self.assertAllEqual(0, self.evaluate(table.size()))
self.evaluate(table.insert(keys, values))
self.assertAllEqual(3, self.evaluate(table.size()))

val = save.save(sess, save_path)
self.assertIsInstance(val, six.string_types)
self.assertEqual(save_path, val)

with self.session(graph=ops.Graph()) as sess:
default_val = -1
table = lookup_ops.MutableHashTable(
dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
self.evaluate(
table.insert(
constant_op.constant(["a", "c"], dtypes.string),
constant_op.constant([12, 24], dtypes.int64)))
self.assertAllEqual(2, self.evaluate(table.size()))

save = saver.Saver([table])

# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.

self.assertAllEqual(3, self.evaluate(table.size()))

input_string = constant_op.constant(["a", "b", "c", "d", "e"],
dtypes.string)
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

@test_util.run_in_graph_and_eager_modes
def testObjectSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/python/ops/lookup_ops.py
Expand Up @@ -1756,7 +1756,8 @@ def export(self, name=None):

def _gather_saveables_for_checkpoint(self):
"""For object-based checkpointing."""
return {"table": functools.partial(MutableHashTable._Saveable, table=self)}
return {"table": functools.partial(
MutableHashTable._Saveable, table=self, name=self._name)}

class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableHashTable."""
Expand Down Expand Up @@ -2048,7 +2049,8 @@ def export(self, name=None):

def _gather_saveables_for_checkpoint(self):
"""For object-based checkpointing."""
return {"table": functools.partial(DenseHashTable._Saveable, table=self)}
return {"table": functools.partial(
DenseHashTable._Saveable, table=self, name=self._name)}

class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for DenseHashTable."""
Expand Down