Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge master into r1.0 to fetch all the fixes. #6672

Merged
merged 167 commits into from Jan 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
167 commits
Select commit Hold shift + click to select a range
1b26bb5
Merge pull request #6463 from gunan/cp
gunan Dec 22, 2016
e2b8784
Change the protobuf dependency from == 3.1.0 to >= 3.1.0 so it doesn't
Dec 8, 2016
e6ec1e2
Update version string to 0.12.1.
gunan Dec 25, 2016
4d924e7
Merge pull request #6491 from gunan/patch
caisq Dec 25, 2016
1a56115
fix spelling/insert correct reference to udacity notebook
rasbt Dec 27, 2016
bb890ea
Make TensorFlow build libcurl from scratch
jart Dec 28, 2016
54d04c6
LinearOperatorComposition: Update concat --> concat_v2.
langmore Dec 28, 2016
814c401
fix a typo.
tensorflower-gardener Dec 28, 2016
152092f
Internal only changes
gunan Dec 28, 2016
ea6ef44
Fix an error in depth_to_space doc (#6535)
gaohuazuo Dec 28, 2016
6af8b79
Merge remote-tracking branch 'gtf/r0.12' into merge
gunan Dec 28, 2016
b61c142
expose contrib.nn module.
martinwicke Dec 28, 2016
1b3d192
Reenable building depthwise_conv_op and depthwise_conv_grad_op on win…
gunan Dec 28, 2016
567973f
Fix: run buildifier on tensorflow/core/kernels/BUILD
gunan Dec 28, 2016
96645ce
Merge pull request #6542 from gunan/merge
gunan Dec 28, 2016
b04d568
tfdbg: add auto-generated Python API doc to gen_docs_combined.py
caisq Dec 28, 2016
4399506
Update generated Python Op docs.
tensorflower-gardener Dec 28, 2016
95a954a
Remove default support for Google Cloud Platform that was introduced …
tensorflower-gardener Dec 28, 2016
8a23c40
Merge pull request #6543 from gunan/windows_fix
gunan Dec 28, 2016
5a4ee93
Add missing numpy and six deps
jart Dec 28, 2016
2d8a854
Seal tf.losses.
Dec 28, 2016
4dfc5d6
Minor fix: allowing min_object_overlap to be 0.
tensorflower-gardener Dec 28, 2016
f7621a6
Update ops-related pbtxt files.
tensorflower-gardener Dec 28, 2016
a1e475c
Fix: Create a unique temporary directory everytime a test runs.
gunan Dec 28, 2016
7604efd
Update generated Python Op docs.
tensorflower-gardener Dec 28, 2016
46ad56e
Android: Copy libs built bazel in android_nightly.sh to output dir (#…
andrewharp Dec 28, 2016
eb66bc7
Enable running android_nightly.sh through ci_parameterized_build.sh
gunan Dec 28, 2016
ce8ca69
Update generated Python Op docs.
tensorflower-gardener Dec 28, 2016
cc0de74
Remove superfluous Python deps
jart Dec 29, 2016
ac28ae0
Merge pull request #6545 from gunan/android
yifeif Dec 29, 2016
d322c05
Add missing std::fixed in StatSummarizer to ensure integers get sorte…
andrewharp Dec 29, 2016
a081f4b
Bug fix: Make `get_local_variable` accessible in the opensource Tenso…
tensorflower-gardener Dec 29, 2016
a1a3b0c
Android: show inference stats on debug screen in demo (accessed with …
andrewharp Dec 29, 2016
513ba3d
Silenced the 'pushd'
z-a-f Dec 29, 2016
3011714
Remove manual nvidia-docker instructions.
gunan Dec 29, 2016
5883a0f
Fix: make sure android_nightly is a recognized container type in ci s…
gunan Dec 29, 2016
e28db9f
Move most foo.BUILD files into third_party
jart Dec 29, 2016
5ecc547
Merge pull request #6559 from gunan/android
yifeif Dec 29, 2016
4d120b7
Use std::move for functions in gpu EventMgr.
tensorflower-gardener Dec 29, 2016
68d66c8
tfdbg doc: rearrange Q&A sections, improve links to docs and external…
caisq Dec 29, 2016
1056e5c
tfdbg: make debug_py target public
caisq Dec 29, 2016
8738a93
Fix 'import test_ops' error.
tensorflower-gardener Dec 29, 2016
c3f6a07
tfdbg CLI: add more clickable links and formatting to analyzer outputs
caisq Dec 29, 2016
714e2a5
Remove stale BUILD files.
yifeif Dec 29, 2016
9c62c06
Include nn in cmake.
tensorflower-gardener Dec 29, 2016
d8cdf23
ResourceOpKernel destructor can't always assume the resource is still…
tensorflower-gardener Dec 29, 2016
ac11a9d
Create a copy of the collections dict's items before doing iteration …
tensorflower-gardener Dec 29, 2016
ab43622
Internal only changes
tensorflower-gardener Dec 29, 2016
25f072f
Merge pull request #6557 from gunan/patch
gunan Dec 29, 2016
6cbc369
Do not remove the test temporary directory after test completes.
gunan Dec 29, 2016
fdd0127
Merge commit for internal changes
yifeif Dec 29, 2016
b720a8d
Fix: Android_nightly build uses android docker container.
gunan Dec 29, 2016
0295391
Adds V3 version of TensorArray ops. All use resource handles and Tens…
tensorflower-gardener Dec 29, 2016
b95789d
Add attention decoder functions for dynamic_rnn_decoder.
tensorflower-gardener Dec 29, 2016
0e05433
Make android_nightly.sh executable.
gunan Dec 30, 2016
a2eba82
Merge pull request #6569 from yifeif/branch_143206951
gunan Dec 30, 2016
f2148d0
Different deprecation annotation for TensorArrayV2 things.
tensorflower-gardener Dec 29, 2016
c506a98
Update ops-related pbtxt files.
tensorflower-gardener Dec 29, 2016
056348f
Update ops-related pbtxt files.
tensorflower-gardener Dec 30, 2016
e133c06
Fix complexity problem in transform node iteration
petewarden Dec 30, 2016
2e22f1b
Merge pull request #6556 from zafartahirov/silence-pushd
caisq Dec 30, 2016
fef50bf
Internal-only change
eliben Dec 30, 2016
7815fcb
Internal-only change
eliben Dec 30, 2016
e121667
Remove so many more hourglass imports
jart Dec 30, 2016
00c45ea
Update generated Python Op docs.
tensorflower-gardener Dec 30, 2016
7455e25
Add missing gradients to tf.ceil and tf.round.
tensorflower-gardener Dec 30, 2016
1243fbe
Deal with the case where _SwitchGrad() is not called the first time f…
tensorflower-gardener Dec 30, 2016
1f46c9f
Add more display options to benchmark, including FLOPs
petewarden Dec 30, 2016
9f360c9
tfdbg CLI: fix bugs related to terminal size
caisq Dec 31, 2016
8501162
Added more fine-grained shape inference for TensorArray such that par…
tensorflower-gardener Dec 31, 2016
9a04727
Merge commit for internal changes
caisq Dec 31, 2016
1435562
Get rid of the last references to FunctionDef.node before we
tensorflower-gardener Dec 31, 2016
fa4ba83
Merge pull request #6583 from caisq/branch_143288671
gunan Jan 1, 2017
5520cd9
1% -> 5% in comment
mbasilyan Jan 1, 2017
eeb13cd
Fixes #6549: Added missing keep_checkpoint_every_n_hours flag
terrytangyuan Jan 1, 2017
59d341d
Merge pull request #6587 from mbasilyan/patch-1
yifeif Jan 1, 2017
55b0159
Merge pull request #6588 from terrytangyuan/run_config_flag
yifeif Jan 1, 2017
90d3b00
Detect and match against full cuda and cudnn versions.
davidzchen Jan 3, 2017
46d2c28
Merge changes from github.
tensorflower-gardener Jan 3, 2017
56da126
Update generated Python Op docs.
tensorflower-gardener Jan 3, 2017
c955f98
Make *args in sv.loop example an iterable
carlthome Jan 3, 2017
eb020c9
Corrected typo.
tensorflower-gardener Jan 3, 2017
e39df7d
Internal only change.
dm-jrae Jan 3, 2017
9bdea3d
For ci_build rename ANDROID_NIGHTLY to ANDROID_FULL.
gunan Jan 3, 2017
346e3d1
Ensure that _CompressHistogram() always returns a CompressedHistogram…
tensorflower-gardener Jan 3, 2017
42a4d5c
Merge commit for internal changes
rohan100jain Jan 3, 2017
f73fe0e
Add a benchmark for TensorFlow RPC performance
Jan 3, 2017
542e4dc
tfdbg doc: emphasize the new required BUILD dependnecy
caisq Jan 3, 2017
6a04c10
Fixed incorrect example for tf.while_loop().
tensorflower-gardener Jan 3, 2017
bb5f900
Merge pull request #6608 from davidzchen/versions
rohan100jain Jan 3, 2017
fafbb33
Merge pull request #6571 from gunan/android
rohan100jain Jan 3, 2017
1514d36
Fix `sample` shape hints and remove `sample_n`.
jvdillon Jan 3, 2017
d13c89d
Merge pull request #6617 from carlthome/patch-3
gunan Jan 3, 2017
50d0651
Merge pull request #6516 from rasbt/udacity_nb3
rohan100jain Jan 3, 2017
7c36309
Merge pull request #6619 from rohan100jain/branch_143464290
rohan100jain Jan 3, 2017
ea579f1
Update generated Python Op docs.
tensorflower-gardener Jan 3, 2017
81d9a24
Update pylintrc.
tensorflower-gardener Jan 3, 2017
7f62ba6
Update documentation for parse_example vs parse_single_example.
tensorflower-gardener Jan 3, 2017
1055b6a
Handle non-tensor args for predictions and labels.
tensorflower-gardener Jan 3, 2017
93fa85e
Update generated Python Op docs.
tensorflower-gardener Jan 3, 2017
2add2f1
Pass Estimator model_dir to the model_fn.
tensorflower-gardener Jan 3, 2017
92b5d2f
Update generated Python Op docs.
tensorflower-gardener Jan 3, 2017
e24d017
Restrict weights rank to be the same as the broadcast target, to avoi…
tensorflower-gardener Jan 3, 2017
2c0fa4e
Remove unused FLAGS.
Jan 3, 2017
ee1f819
Add control edge support to TensorId.
skye Jan 3, 2017
74edc58
Fix a bug in sparse_softmax_cross_entropy for weights of unspecified …
tensorflower-gardener Jan 3, 2017
a023d0b
Update generated Python Op docs.
tensorflower-gardener Jan 3, 2017
35f8b1f
Fix freeze_graph.
tensorflower-gardener Jan 4, 2017
109c03d
Android: add Timer utility class for measuring cpu and wall time.
andrewharp Jan 4, 2017
6ec984e
Adds the following new ops:
tensorflower-gardener Jan 4, 2017
d352573
Update ops-related pbtxt files.
tensorflower-gardener Jan 4, 2017
d4109cb
Optimize im2col section of quantized convolution
petewarden Jan 4, 2017
8a6014b
Fix usage of tensorflow namespace in graph_to_dot
petewarden Jan 4, 2017
7c63520
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
6703501
Added Experiment integration tests with custom Estimator, linear/dnn/…
ispirmustafa Jan 4, 2017
271b7f3
tfdbg CLI: let list_tensors (lt) output display dump file size
caisq Jan 4, 2017
99125e7
tfdbg doc: minor fix re. command-line flags
caisq Jan 4, 2017
48da18f
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
b217f61
Remove unused TF_NEED_SYCL from ./configure.
pwnall Jan 4, 2017
51e5d17
Convert tf.flags usage to argparse. Move use of FLAGS globals into m…
Jan 4, 2017
97866c1
Automated rollback of change 143523842
tensorflower-gardener Jan 4, 2017
e0ec343
LinearOperator (base class), prefer statically defined shape if avail…
langmore Jan 4, 2017
8cb009b
Updated description of CheckpointSaver.
tensorflower-gardener Jan 4, 2017
843974d
Fix typo (though --> through) in tf.placeholder_with_default().
tensorflower-gardener Jan 4, 2017
8cfffcf
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
e78b994
Update ops-related pbtxt files.
tensorflower-gardener Jan 4, 2017
ea29616
Fix parsing of Python command-line arguments in tests.
hawkinsp Jan 4, 2017
012800e
Change for internal compatibility.
tensorflower-gardener Jan 4, 2017
3b5f50d
Add support for byte-level native access for Android TensorFlow.
tensorflower-gardener Jan 4, 2017
1118de0
Mark gemmlowp result as initialized.
Jan 4, 2017
bd97023
Switch tf-learn BaseEstimator.evaluate() to using evaluation.evaluate…
caisq Jan 4, 2017
4982c62
Add deprecation warnings to tf.neg and prepare for deprecation warnin…
aselle Jan 4, 2017
e9a28c5
Remove mnist dependency in windows, as it causes linker issues.
gunan Jan 3, 2017
ffd7338
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
354972d
Move SIMD feature warnings to the first use of intensive CPU computat…
petewarden Jan 4, 2017
2482564
Adds V2 versions of Queue and Reader ops using ResourceHandles.
tensorflower-gardener Jan 4, 2017
d7b1d0a
Update ops-related pbtxt files.
tensorflower-gardener Jan 4, 2017
e1eae19
Android: add debug-specific overlay for detection activity that can b…
andrewharp Jan 4, 2017
2522285
Allow fully dynamic batch/event overrides.
jvdillon Jan 4, 2017
fc8dd9f
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
70d4f7e
LinearOperatorIdentity added to tensorflow/contrib/linalg/
langmore Jan 4, 2017
8d072f5
Android: add support for object names in MultiboxTracker
andrewharp Jan 4, 2017
d67c09d
Merge pull request #6621 from gunan/windows
gunan Jan 4, 2017
f74af69
Update generated Python Op docs.
tensorflower-gardener Jan 4, 2017
ddedae6
Make srcs and deps arguments to tf_cuda_cc_test build rule optional.
hawkinsp Jan 4, 2017
fcc319a
Remove unused FLAGS variable.
Jan 4, 2017
2eb1604
Make Empty Op Stateful.
tensorflower-gardener Jan 4, 2017
de053cf
Update ops-related pbtxt files.
tensorflower-gardener Jan 4, 2017
d3822f0
Merge pull request #6628 from pwnall/fix_configure
gunan Jan 4, 2017
73eff47
Update callers of array_ops.concat to call array_ops.concat_v2 instea…
tensorflower-gardener Jan 4, 2017
02d2385
Include stream_executor headers in pip package include directory.
Jan 4, 2017
b17f1e2
Removing comments for investigation of the root cause of test toleran…
tensorflower-gardener Jan 4, 2017
eba10b7
Defer optimizer function run in linear classifier until apply gradien…
tensorflower-gardener Jan 4, 2017
e43b9d8
Remove a few ununsed functions.
lilao Jan 5, 2017
f6d47fa
Improved documentation for OpenCL setup
benoitsteiner Jan 5, 2017
a31acbe
Remove pending inputs from RunState of DirectSession::Run.
Jan 5, 2017
37b430c
Moving FinalOpsHook into basic_session_run_hooks.
Jan 5, 2017
bf00bcc
Provide multiple implementations of RPC requests on the feed path.
mrry Jan 5, 2017
1628abf
Fixing problem with restoring scope with partitioned variables.
Jan 5, 2017
d954169
Update generated Python Op docs.
tensorflower-gardener Jan 5, 2017
333dc32
Change arg order for {softmax,sparse_softmax,sigmoid}_cross_entropy_w…
martinwicke Jan 5, 2017
b9b7b88
Update generated Python Op docs.
tensorflower-gardener Jan 5, 2017
ecf97ee
worspace.bzl uses zlib permalink (#6612)
ahundt Jan 5, 2017
7c97527
Make labeled_tensor use tf.contrib.nn.deprecated_flipped_* versions o…
martinwicke Jan 5, 2017
83a98cc
Merge commit for internal changes
caisq Jan 5, 2017
4433079
Merge pull request #6667 from caisq/branch_143639671
rohan100jain Jan 5, 2017
f07e1ab
Merge remote-tracking branch 'tensorflow/master' into r1.0
gunan Jan 5, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
5 changes: 2 additions & 3 deletions configure
Expand Up @@ -4,7 +4,7 @@ set -e
set -o pipefail

# Find out the absolute path to where ./configure resides
pushd `dirname $0` #> /dev/null
pushd `dirname $0` > /dev/null
SOURCE_BASE_DIR=`pwd -P`
popd > /dev/null

Expand Down Expand Up @@ -145,7 +145,7 @@ while [ "$TF_NEED_CUDA" == "" ]; do
done

export TF_NEED_CUDA
export TF_NEED_SYCL
export TF_NEED_OPENCL
if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
echo "Configuration finished"
bazel_clean_and_fetch
Expand Down Expand Up @@ -465,7 +465,6 @@ while true; do
COMPUTECPP_TOOLKIT_PATH=""
done

export TF_NEED_OPENCL
# end of if "$TF_NEED_OPENCL" == "1"
fi

Expand Down
1 change: 1 addition & 0 deletions tensorflow/BUILD
Expand Up @@ -133,6 +133,7 @@ filegroup(
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nn:all_files",
"//tensorflow/contrib/opt:all_files",
"//tensorflow/contrib/rnn:all_files",
"//tensorflow/contrib/seq2seq:all_files",
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/contrib/BUILD
Expand Up @@ -38,6 +38,7 @@ py_library(
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/rnn:rnn_py",
Expand All @@ -47,8 +48,8 @@ py_library(
"//tensorflow/contrib/solvers:solvers_py",
"//tensorflow/contrib/specs",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/tensor_forest/hybrid:ops_lib",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensor_forest/hybrid:ops_lib", # XXX: no ref but need for pip
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/tfprof",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/__init__.py
Expand Up @@ -42,6 +42,7 @@
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import metrics
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import quantization
from tensorflow.contrib import rnn
Expand Down
Expand Up @@ -76,23 +76,31 @@ public TensorFlowInferenceInterface() {
*/
public native int runInference(String[] outputNames);

/**
* Whether to collect and log stats to logcat during inference via StepStats and StatSummarizer.
* This should only be enabled when needed, as it will add overhead.
*/
public native void enableStatLogging(boolean enabled);

/** Returns the last stat summary string if logging is enabled. */
public native String getStatString();

/**
* Cleans up the native variables associated with this Object. initializeTensorFlow() can then
* be called again to initialize a new session.
*
*/
public native void close();

// Methods for creating a native Tensor and filling it with values.
public native void fillNodeFloat(String inputName, int[] dims, float[] values);

public native void fillNodeInt(String inputName, int[] dims, int[] values);

public native void fillNodeDouble(String inputName, int[] dims, double[] values);
public native void fillNodeByte(String inputName, int[] dims, byte[] values);

public native void readNodeFloat(String outputName, float[] values);
public native void readNodeInt(String outputName, int[] values);
public native void readNodeDouble(String outputName, double[] values);
public native void readNodeByte(String outputName, byte[] values);

/**
* Canary method solely for determining if the tensorflow_inference native library should be
Expand Down
73 changes: 61 additions & 12 deletions tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
Expand Up @@ -53,6 +53,9 @@ struct SessionVariables {
int num_runs = 0;
int64 timing_total_us = 0;

bool log_stats = false;
StatSummarizer* summarizer = nullptr;

InputMap input_tensors;
std::vector<std::string> output_tensor_names;
std::vector<tensorflow::Tensor> output_tensors;
Expand Down Expand Up @@ -129,6 +132,10 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
LOG(INFO) << "GraphDef loaded from " << model_str << " with "
<< tensorflow_graph.node_size() << " nodes.";

// Whether or not stat logging is currently enabled, the StatSummarizer must
// be initialized here with the GraphDef while it is available.
vars->summarizer = new StatSummarizer(tensorflow_graph);

LOG(INFO) << "Creating TensorFlow graph from GraphDef.";
tensorflow::Status s = session->Create(tensorflow_graph);

Expand Down Expand Up @@ -193,8 +200,28 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
}

vars->output_tensors.clear();
s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
&(vars->output_tensors));

if (vars->log_stats) {
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;

s = vars->session->Run(run_options, input_tensors,
vars->output_tensor_names, {},
&(vars->output_tensors), &run_metadata);

assert(run_metadata.has_step_stats());
const StepStats& step_stats = run_metadata.step_stats();
vars->summarizer->ProcessStepStats(step_stats);

// Print the full output string, not just the abbreviated one returned by
// getStatString().
vars->summarizer->PrintStepStats();
} else {
s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
&(vars->output_tensors));
}

end_time = CurrentWallTimeUs();
const int64 elapsed_time_inf = end_time - start_time;
vars->timing_total_us += elapsed_time_inf;
Expand All @@ -208,6 +235,24 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
return s.code();
}

JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
JNIEnv* env, jobject thiz, jboolean enableStatLogging) {
SessionVariables* vars = GetSessionVars(env, thiz);
vars->log_stats = enableStatLogging;
}

JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
jobject thiz) {
// Return an abbreviated stat string suitable for displaying on screen.
SessionVariables* vars = GetSessionVars(env, thiz);
std::stringstream ss;
ss << vars->summarizer->GetStatsByMetric("Top 10 CPU",
StatSummarizer::BY_TIME, 10);
ss << vars->summarizer->GetStatsByNodeType();
ss << vars->summarizer->ShortSummary();
return env->NewStringUTF(ss.str().c_str());
}

JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
SessionVariables* vars = GetSessionVars(env, thiz);

Expand All @@ -216,6 +261,8 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
LOG(ERROR) << "Error closing session: " << s;
}

delete vars->summarizer;

mutex_lock l(mutex_);
std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
sessions.erase(vars->id);
Expand All @@ -225,7 +272,7 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
}

// TODO(andrewharp): Use memcpy to fill/read nodes.
#define FILL_NODE_METHOD(DTYPE, JAVA_DTYPE, TENSOR_DTYPE) \
#define FILL_NODE_METHOD(DTYPE, JAVA_DTYPE, CTYPE, TENSOR_DTYPE) \
FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
SessionVariables* vars = GetSessionVars(env, thiz); \
jboolean iCopied = JNI_FALSE; \
Expand All @@ -237,7 +284,7 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
} \
env->ReleaseIntArrayElements(dims, dim_vals, JNI_ABORT); \
tensorflow::Tensor input_tensor(TENSOR_DTYPE, shape); \
auto tensor_mapped = input_tensor.flat<JAVA_DTYPE>(); \
auto tensor_mapped = input_tensor.flat<CTYPE>(); \
j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
j##JAVA_DTYPE* value_ptr = values; \
const int array_size = env->GetArrayLength(arr); \
Expand All @@ -253,14 +300,14 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
vars->input_tensors[input_name] = input_pair; \
}

#define READ_NODE_METHOD(DTYPE, JAVA_DTYPE) \
#define READ_NODE_METHOD(DTYPE, JAVA_DTYPE, CTYPE) \
READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
SessionVariables* vars = GetSessionVars(env, thiz); \
Tensor* t = GetTensor(env, thiz, node_name_jstring); \
if (t == nullptr) { \
return -1; \
} \
auto tensor_mapped = t->flat<JAVA_DTYPE>(); \
auto tensor_mapped = t->flat<CTYPE>(); \
jboolean iCopied = JNI_FALSE; \
j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
j##JAVA_DTYPE* value_ptr = values; \
Expand All @@ -273,10 +320,12 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
return 0; \
}

FILL_NODE_METHOD(Float, float, tensorflow::DT_FLOAT)
FILL_NODE_METHOD(Int, int, tensorflow::DT_INT32)
FILL_NODE_METHOD(Double, double, tensorflow::DT_DOUBLE)
FILL_NODE_METHOD(Float, float, float, tensorflow::DT_FLOAT)
FILL_NODE_METHOD(Int, int, int, tensorflow::DT_INT32)
FILL_NODE_METHOD(Double, double, double, tensorflow::DT_DOUBLE)
FILL_NODE_METHOD(Byte, byte, uint8_t, tensorflow::DT_UINT8)

READ_NODE_METHOD(Float, float)
READ_NODE_METHOD(Int, int)
READ_NODE_METHOD(Double, double)
READ_NODE_METHOD(Float, float, float)
READ_NODE_METHOD(Int, int, int)
READ_NODE_METHOD(Double, double, double)
READ_NODE_METHOD(Byte, byte, uint8_t)
8 changes: 8 additions & 0 deletions tensorflow/contrib/android/jni/tensorflow_inference_jni.h
Expand Up @@ -48,15 +48,23 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
JNIEnv* env, jobject thiz, jobjectArray output_name_strings);

JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
JNIEnv* env, jobject thiz, jboolean enableStatLogging);

JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
jobject thiz);

JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz);

FILL_NODE_SIGNATURE(Float, float);
FILL_NODE_SIGNATURE(Int, int);
FILL_NODE_SIGNATURE(Double, double);
FILL_NODE_SIGNATURE(Byte, byte);

READ_NODE_SIGNATURE(Float, float);
READ_NODE_SIGNATURE(Int, int);
READ_NODE_SIGNATURE(Double, double);
READ_NODE_SIGNATURE(Byte, byte);

#ifdef __cplusplus
} // extern "C"
Expand Down
64 changes: 56 additions & 8 deletions tensorflow/contrib/bayesflow/BUILD
Expand Up @@ -26,6 +26,8 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//third_party/py/numpy",
"@six_archive//:six",
],
)

Expand All @@ -35,9 +37,16 @@ cuda_py_test(
srcs = ["python/kernel_tests/entropy_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)

Expand All @@ -47,9 +56,17 @@ cuda_py_test(
srcs = ["python/kernel_tests/stochastic_variables_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)

Expand All @@ -59,8 +76,13 @@ cuda_py_test(
srcs = ["python/kernel_tests/monte_carlo_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
Expand All @@ -71,9 +93,13 @@ cuda_py_test(
srcs = ["python/kernel_tests/special_math_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)

Expand All @@ -83,8 +109,14 @@ cuda_py_test(
srcs = ["python/kernel_tests/stochastic_graph_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
Expand All @@ -95,9 +127,15 @@ cuda_py_test(
srcs = ["python/kernel_tests/variational_inference_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)

Expand All @@ -107,7 +145,11 @@ cuda_py_test(
srcs = ["python/kernel_tests/stochastic_tensor_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
Expand All @@ -119,9 +161,15 @@ cuda_py_test(
srcs = ["python/kernel_tests/stochastic_gradient_estimators_test.py"],
additional_deps = [
":bayesflow_py",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)

Expand Down