Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 172647355 #13819

Merged
merged 22 commits into from
Oct 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
35debbd
Update inception score to match the openAI version from https://githu…
tensorflower-gardener Oct 18, 2017
130ec39
Stub support for retrieving LossFunction by name.
tensorflower-gardener Oct 18, 2017
3c31886
Don't emit fusion computations separately in HloModule::ToString. The…
meheffernan Oct 18, 2017
c9d3377
Make `tf.contrib.distributions` quadrature family parameterized by
jvdillon Oct 18, 2017
bc08226
Fixes test breakage.
alextp Oct 18, 2017
21f68a8
Remove global step read dependency from model_fn. Estimator behavior …
ispirmustafa Oct 18, 2017
1a325f1
More changs to avoid flakes in random_shuffle_queue_test
gunan Oct 18, 2017
6a725f6
Add expected keys to predictor exception if unexpected key detected.
tensorflower-gardener Oct 18, 2017
f5d3bf4
Add TF_GraphGetOpDef() to C API and use in Operation.op_def()
skye Oct 18, 2017
f5ea388
Implement ZlibInputStream::Tell() by keeping track of the number of b…
saxenasaurabh Oct 18, 2017
ef060d9
Upgrade tensorflow pip dependency version to 3.4.0+
caisq Oct 18, 2017
f1603b7
[XLA] Deterministically dump an executable.
yunxing Oct 18, 2017
192f1c2
Fixed work size computation in Split and SplitV ops to avoid integer …
tensorflower-gardener Oct 18, 2017
09ff3f7
Internal change.
tensorflower-gardener Oct 18, 2017
38bcb3c
Bug fixes for fold_constants_lib.
tensorflower-gardener Oct 18, 2017
b7e8533
Adds visibility to sgdr_learning_rate_decay.
tensorflower-gardener Oct 18, 2017
f6968a2
Add logging verbosity to mnist.py
Oct 18, 2017
08aeb0f
Automated g4 rollback of changelist 172336111
allenlavoie Oct 18, 2017
d65f7b9
Correct the docstring to reflect that the values of cols_to_vars are …
tensorflower-gardener Oct 18, 2017
5565aac
Changes MultiLabelHead.create_loss to return a Tensor of size [batch_…
tensorflower-gardener Oct 18, 2017
ff4d978
Fixing conflict in setup.py
Oct 18, 2017
a873cf4
Disabling failing contrib tests.
Oct 18, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions tensorflow/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,17 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
status->status = MessageToBuffer(def, output_graph_def);
}

void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
TF_Buffer* output_op_def, TF_Status* status) {
const OpDef* op_def;
{
mutex_lock l(graph->mu);
status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
if (!status->status.ok()) return;
}
status->status = MessageToBuffer(*op_def, output_op_def);
}

TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
return new TF_ImportGraphDefOptions;
}
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/c/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,13 @@ TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph,
TF_Buffer* output_graph_def,
TF_Status* status);

// Returns the serialized OpDef proto with name `op_name`, or a bad status if no
// such op exists. This can return OpDefs of functions copied into the graph.
TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph,
const char* op_name,
TF_Buffer* output_op_def,
TF_Status* status);

// TF_ImportGraphDefOptions holds options that can be passed to
// TF_GraphImportGraphDef.
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
Expand Down
21 changes: 21 additions & 0 deletions tensorflow/c/c_api_function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1465,5 +1465,26 @@ TEST_F(CApiFunctionTest, AppendHash) {
ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
}

TEST_F(CApiFunctionTest, GetOpDef) {
DefineFunction(func_name_, &func_);
TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

// Test we can retrieve function OpDef from graph
TF_Buffer* buffer = TF_NewBuffer();
TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

// Sanity check returned OpDef
string data(static_cast<const char*>(buffer->data), buffer->length);
OpDef op_def;
op_def.ParseFromString(data);
EXPECT_EQ(op_def.name(), func_name_);
EXPECT_EQ(op_def.input_arg_size(), 1);
EXPECT_EQ(op_def.output_arg_size(), 1);

TF_DeleteBuffer(buffer);
}

} // namespace
} // namespace tensorflow
31 changes: 31 additions & 0 deletions tensorflow/c/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
Expand All @@ -50,6 +51,11 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);

namespace {

static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(StringPiece(s).contains(expected))
<< "'" << s << "' does not contain '" << expected << "'";
}

TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); }

TEST(CAPI, Status) {
Expand Down Expand Up @@ -837,6 +843,31 @@ TEST(CAPI, ShapeInferenceError) {
TF_DeleteStatus(status);
}

TEST(CAPI, GetOpDef) {
TF_Status* status = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
TF_Buffer* buffer = TF_NewBuffer();

TF_GraphGetOpDef(graph, "Add", buffer, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
const OpDef* expected_op_def;
TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
string expected_serialized;
expected_op_def->SerializeToString(&expected_serialized);
string actual_string(reinterpret_cast<const char*>(buffer->data),
buffer->length);
EXPECT_EQ(expected_serialized, actual_string);

TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status));
ExpectHasSubstr(TF_Message(status),
"Op type not registered 'MyFakeOp' in binary");

TF_DeleteBuffer(buffer);
TF_DeleteGraph(graph);
TF_DeleteStatus(status);
}

void StringVectorToArrays(const std::vector<string>& v,
std::unique_ptr<const void* []>* ptrs,
std::unique_ptr<size_t[]>* lens) {
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":status",
":status_macros",
":types",
":xla_data_proto",
"//tensorflow/core:lib",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -579,12 +579,14 @@ cc_library(
":shaped_buffer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
],
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/compiler/xla/service/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.

#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
Expand Down Expand Up @@ -82,7 +84,11 @@ Status Executable::DumpSessionModule() {
}
filename = SanitizeFileName(std::move(filename));
string file_path = tensorflow::io::JoinPath(directory_path, filename);
return tensorflow::WriteBinaryProto(env, file_path, session_module);
string result;
TF_RET_CHECK(
tensorflow::SerializeToStringDeterministic(session_module, &result));
return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path,
result);
}

} // namespace xla
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/hlo_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ string HloModule::ToString() const {
std::ostringstream s;
s << "HloModule " << name() << ":\n\n";
s << "ENTRY " << entry_computation()->ToString() << "\n\n";
for (const std::unique_ptr<HloComputation>& computation : computations_) {
if (computation.get() != entry_computation()) {
for (const HloComputation* computation : MakeNonfusionComputations()) {
if (computation != entry_computation()) {
s << computation->ToString() << "\n\n";
}
}
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/contrib/data/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958604
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
Expand Down Expand Up @@ -352,6 +355,9 @@ py_test(
size = "small",
srcs = ["sloppy_transformation_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958761
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import poisson_lognormal
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.platform import test
Expand All @@ -32,7 +34,8 @@ def testSampleProbConsistent(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=-2.,
scale=1.1,
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1)
Expand All @@ -42,7 +45,8 @@ def testMeanVariance(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=0.,
scale=1.,
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.02)
Expand All @@ -52,7 +56,8 @@ def testSampleProbConsistentBroadcastScalar(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.01)
Expand All @@ -62,7 +67,8 @@ def testMeanVarianceBroadcastScalar(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
Expand All @@ -72,7 +78,8 @@ def testSampleProbConsistentBroadcastBoth(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.08)
Expand All @@ -82,7 +89,8 @@ def testMeanVarianceBroadcastBoth(self):
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
Expand Down
54 changes: 33 additions & 21 deletions tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
: d=0, ..., deg-1 }
```

where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html)
and `prob = w / sqrt(pi)`.

Expand All @@ -106,14 +106,15 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
pln = ds.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
quadrature_grid_and_probs=(
np.polynomial.hermite.hermgauss(deg=10)),
validate_args=True)
"""

def __init__(self,
loc,
scale,
quadrature_polynomial_degree=8,
quadrature_grid_and_probs=None,
validate_args=False,
allow_nan_stats=True,
name="PoissonLogNormalQuadratureCompound"):
Expand All @@ -124,8 +125,9 @@ def __init__(self,
the LogNormal prior.
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
the LogNormal prior.
quadrature_polynomial_degree: Python `int`-like scalar.
Default value: 8.
quadrature_grid_and_probs: Python pair of `list`-like objects representing
the sample points and the corresponding (possibly normalized) weight.
When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
Expand All @@ -138,6 +140,8 @@ def __init__(self,

Raises:
TypeError: if `loc.dtype != scale[0].dtype`.
ValueError: if `quadrature_grid_and_probs is not None` and
`len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
Expand All @@ -153,18 +157,21 @@ def __init__(self,
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
loc.dtype.name, scale.dtype.name))

self._degree = quadrature_polynomial_degree

grid, prob = np.polynomial.hermite.hermgauss(
deg=quadrature_polynomial_degree)

# It should be that `sum(prob) == sqrt(pi)`, but self-normalization is
# more numerically stable.
prob = prob.astype(dtype.as_numpy_dtype)
prob /= np.linalg.norm(prob, ord=1)
if quadrature_grid_and_probs is None:
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
else:
grid, probs = tuple(quadrature_grid_and_probs)
if len(grid) != len(probs):
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
"same-length list-like objects")
grid = grid.astype(dtype.as_numpy_dtype)
probs = probs.astype(dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1)
self._quadrature_grid = grid
self._quadrature_probs = probs

self._mixture_distribution = categorical_lib.Categorical(
logits=np.log(prob),
logits=np.log(probs),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)

Expand Down Expand Up @@ -210,9 +217,14 @@ def scale(self):
return self._scale

@property
def quadrature_polynomial_degree(self):
"""Polynomial largest exponent used for Gauss-Hermite quadrature."""
return self._degree
def quadrature_grid(self):
"""Quadrature grid points."""
return self._quadrature_grid

@property
def quadrature_probs(self):
"""Quadrature normalized weights."""
return self._quadrature_probs

def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
Expand Down Expand Up @@ -242,10 +254,10 @@ def _sample_n(self, n, seed=None):
[batch_size])),
seed=distribution_util.gen_new_seed(
seed, "poisson_lognormal_quadrature_compound"))
# Stride `quadrature_polynomial_degree` for `batch_size` number of times.
# Stride `quadrature_degree` for `batch_size` number of times.
offset = math_ops.range(start=0,
limit=batch_size * self._degree,
delta=self._degree,
limit=batch_size * len(self.quadrature_probs),
delta=len(self.quadrature_probs),
dtype=ids.dtype)
ids += offset
rate = array_ops.gather(
Expand Down