Skip to content

Commit

Permalink
Move MonitoredSession and related utilities from tf.contrib.learn to …
Browse files Browse the repository at this point in the history
…tf.train

Change: 135010812
  • Loading branch information
ispirmustafa authored and tensorflower-gardener committed Oct 3, 2016
1 parent b1bd36b commit 82b907e
Show file tree
Hide file tree
Showing 20 changed files with 1,870 additions and 1,426 deletions.
58 changes: 6 additions & 52 deletions tensorflow/contrib/framework/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.load_library import load_op_library
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.framework.load_library import load_op_library
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util


__all__ = ['add_model_variable',
Expand Down Expand Up @@ -78,31 +79,15 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"):
assert _variable_ops, "Could not load _variable_ops.so"
return gen_variable_ops.zero_initializer(ref, name=name)


# shape function for _ZeroInitializerOp
@ops.RegisterShape("ZeroInitializer")
def _ZeroInitializerShape(op):
var_shape = op.inputs[0].get_shape()
return [var_shape]

def assert_global_step(global_step_tensor):
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
Args:
global_step_tensor: `Tensor` to test.
"""
if not (isinstance(global_step_tensor, variables.Variable) or
isinstance(global_step_tensor, ops.Tensor)):
raise TypeError('Existing "global_step" must be a Variable or Tensor.')

if not global_step_tensor.dtype.base_dtype.is_integer:
raise TypeError(
'Existing "global_step" does not have integer type: %s' %
global_step_tensor.dtype)

if global_step_tensor.get_shape().ndims != 0:
raise TypeError(
'Existing "global_step" is not scalar: %s' %
global_step_tensor.get_shape())
training_util.assert_global_step(global_step_tensor)


def assert_or_get_global_step(graph=None, global_step_tensor=None):
Expand All @@ -129,39 +114,8 @@ def assert_or_get_global_step(graph=None, global_step_tensor=None):
return global_step_tensor


# TODO(ptucker): Change supervisor to use this when it's migrated to core.
def get_global_step(graph=None):
"""Get the global step tensor.
The global step tensor must be an integer variable. We first try to find it
in the collection `GLOBAL_STEP`, or by name `global_step:0`.
Args:
graph: The graph to find the global step in. If missing, use default graph.
Returns:
The global step variable, or `None` if none was found.
Raises:
TypeError: If the global step tensor has a non-integer type, or if it is not
a `Variable`.
"""
graph = ops.get_default_graph() if graph is None else graph
global_step_tensor = None
global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
if len(global_step_tensors) == 1:
global_step_tensor = global_step_tensors[0]
elif not global_step_tensors:
try:
global_step_tensor = graph.get_tensor_by_name('global_step:0')
except KeyError:
return None
else:
logging.error('Multiple tensors in global_step collection.')
return None

assert_global_step(global_step_tensor)
return global_step_tensor
return training_util.get_global_step(graph)


def create_global_step(graph=None):
Expand Down
22 changes: 0 additions & 22 deletions tensorflow/contrib/learn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,6 @@ py_test(
],
)

py_test(
name = "monitored_session_test",
size = "small",
srcs = ["python/learn/tests/monitored_session_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
],
)

py_test(
name = "monitors_test",
size = "small",
Expand Down Expand Up @@ -703,16 +691,6 @@ py_test(
],
)

py_test(
name = "summary_writer_cache_test",
size = "small",
srcs = ["python/learn/tests/summary_writer_cache_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)

py_binary(
name = "inspect_checkpoint",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/learn/python/learn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np

# pylint: disable=wildcard-import
from tensorflow.contrib.learn.python.learn import basic_session_run_hooks
from tensorflow.contrib.learn.python.learn import datasets
from tensorflow.contrib.learn.python.learn import estimators
from tensorflow.contrib.learn.python.learn import graph_actions
Expand Down
Loading

0 comments on commit 82b907e

Please sign in to comment.