add model restore support for tree and forest variables #17070

Merged: 5 commits, May 22, 2018

Changes from all commits
34 changes: 22 additions & 12 deletions tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -295,15 +295,15 @@ def get_epoch_variable():


# A simple container to hold the training variables for a single tree.
-class TreeTrainingVariables(object):
+class TreeVariables(object):
"""Stores tf.Variables for training a single random tree.

Uses tf.get_variable to get tree-specific names so that this can be used
with a tf.learn-style implementation (one that trains a model, saves it,
then relies on restoring that model to evaluate).
"""

-def __init__(self, params, tree_num, training):
+def __init__(self, params, tree_num, training, tree_config='', tree_stat=''):
if (not hasattr(params, 'params_proto') or
not isinstance(params.params_proto,
_params_proto.TensorForestParams)):
@@ -315,27 +315,28 @@ def __init__(self, params, tree_num, training):
# TODO(gilberth): Manually shard this to be able to fit it on
# multiple machines.
self.stats = stats_ops.fertile_stats_variable(
-params, '', self.get_tree_name('stats', tree_num))
+params, tree_stat, self.get_tree_name('stats', tree_num))
self.tree = model_ops.tree_variable(
-params, '', self.stats, self.get_tree_name('tree', tree_num))
+params, tree_config, self.stats, self.get_tree_name('tree', tree_num))

def get_tree_name(self, name, num):
return '{0}-{1}'.format(name, num)


-class ForestTrainingVariables(object):
+class ForestVariables(object):
"""A container for a forests training data, consisting of multiple trees.

-Instantiates a TreeTrainingVariables object for each tree. We override the
+Instantiates a TreeVariables object for each tree. We override the
__getitem__ and __setitem__ function so that usage looks like this:

-forest_variables = ForestTrainingVariables(params)
+forest_variables = ForestVariables(params)

... forest_variables.tree ...
"""

def __init__(self, params, device_assigner, training=True,
-tree_variables_class=TreeTrainingVariables):
+tree_variables_class=TreeVariables,
+tree_configs=None, tree_stats=None):
self.variables = []
# Set up some scalar variables to run through the device assigner, then
# we can use those to colocate everything related to a tree.
@@ -347,7 +348,13 @@ def __init__(self, params, device_assigner, training=True,

for i in range(params.num_trees):
with ops.device(self.device_dummies[i].device):
-self.variables.append(tree_variables_class(params, i, training))
+kwargs = {}
+if tree_configs is not None:
+  kwargs.update(dict(tree_config=tree_configs[i]))
+if tree_stats is not None:
+  kwargs.update(dict(tree_stat=tree_stats[i]))
+self.variables.append(tree_variables_class(
+    params, i, training, **kwargs))

def __setitem__(self, t, val):
self.variables[t] = val
@@ -361,19 +368,22 @@ class RandomForestGraphs(object):

def __init__(self,
params,
+tree_configs=None,
+tree_stats=None,
device_assigner=None,
variables=None,
-tree_variables_class=TreeTrainingVariables,
+tree_variables_class=TreeVariables,
tree_graphs=None,
training=True):
self.params = params
self.device_assigner = (
device_assigner or framework_variables.VariableDeviceChooser())
logging.info('Constructing forest with params = ')
logging.info(self.params.__dict__)
-self.variables = variables or ForestTrainingVariables(
+self.variables = variables or ForestVariables(
self.params, device_assigner=self.device_assigner, training=training,
-tree_variables_class=tree_variables_class)
+tree_variables_class=tree_variables_class,
+tree_configs=tree_configs, tree_stats=tree_stats)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(self.variables[i], self.params, i)
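The changes above thread an optional serialized tree through the constructors: RandomForestGraphs(tree_configs=...) hands the list to ForestVariables, which passes tree_configs[i] to each TreeVariables as tree_config, which in turn reaches model_ops.tree_variable in place of the previous hard-coded ''. A minimal sketch of the resulting call pattern (not part of the PR; the hparams values are copied from the test below, and the empty Model proto is only a placeholder for a real serialized decision_trees.Model):

```python
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.tensor_forest.python import tensor_forest

params = tensor_forest.ForestHParams(
    num_classes=2,
    num_features=2,
    num_trees=1,
    max_nodes=1000,
    split_after_samples=25).fill()

# One serialized decision_trees.Model proto per tree. An empty Model is just a
# placeholder; the test below builds a real one with ParseDict.
restored_configs = [_tree_proto.Model().SerializeToString()]

# tree_configs seeds each tree variable with the restored model instead of an
# empty config, so inference can run without retraining.
graph_builder = tensor_forest.RandomForestGraphs(
    params, tree_configs=restored_configs)
probs, paths, var = graph_builder.inference_graph([[-1., 0.], [1., 0.]])
```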
45 changes: 45 additions & 0 deletions tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -18,10 +18,14 @@
from __future__ import division
from __future__ import print_function

from google.protobuf.json_format import ParseDict
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest


@@ -110,6 +114,47 @@ def testInferenceConstruction(self):
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))

def testInferenceFromRestoredModel(self):
input_data = [[-1., 0.], [-1., 2.], # node 1
[1., 0.], [1., -2.]] # node 2
expected_prediction = [[0.0, 1.0], [0.0, 1.0],
[0.0, 1.0], [0.0, 1.0]]
hparams = tensor_forest.ForestHParams(
num_classes=2,
num_features=2,
num_trees=1,
max_nodes=1000,
split_after_samples=25).fill()
tree_weight = {'decisionTree':
{'nodes':
[{'binaryNode':
{'rightChildId': 2,
'leftChildId': 1,
'inequalityLeftChildTest':
{'featureId': {'id': '0'},
'threshold': {'floatValue': 0}}}},
{'leaf': {'vector':
{'value': [{'floatValue': 0.0},
{'floatValue': 1.0}]}},
'nodeId': 1},
{'leaf': {'vector':
{'value': [{'floatValue': 0.0},
{'floatValue': 1.0}]}},
'nodeId': 2}]}}
restored_tree_param = ParseDict(tree_weight,
_tree_proto.Model()).SerializeToString()
graph_builder = tensor_forest.RandomForestGraphs(hparams,
[restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
with self.test_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
self.assertEquals(probs.eval().tolist(), expected_prediction)

def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
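The new test exercises only the tree_configs path; tree_stats is wired the same way, with ForestVariables forwarding tree_stats[i] to stats_ops.fertile_stats_variable. A hedged sketch of resuming training from both restored pieces (not from the PR; the empty strings stand in for previously exported tree and stats protos, which this PR does not add an export helper for, and the training_graph call mirrors the existing construction tests in this file):

```python
from tensorflow.contrib.tensor_forest.python import tensor_forest

params = tensor_forest.ForestHParams(
    num_classes=2,
    num_features=2,
    num_trees=1,
    max_nodes=1000,
    split_after_samples=25).fill()

input_data = [[-1., 0.], [-1., 2.], [1., 0.], [1., -2.]]
input_labels = [1, 1, 1, 1]

# Hypothetical placeholders: one serialized tree config and one serialized
# stats string per tree. Empty strings fall back to the pre-PR behavior of
# starting each tree from scratch.
restored_configs = ['']
restored_stats = ['']

# tree_stats is passed straight through to the fertile stats variable, just as
# tree_configs is passed to the tree variable, so training resumes from the
# restored state rather than an empty forest.
graph_builder = tensor_forest.RandomForestGraphs(
    params, tree_configs=restored_configs, tree_stats=restored_stats)
train_op = graph_builder.training_graph(input_data, input_labels)
```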