Skip to content

Commit

Permalink
Android: make ObjecTracker field lookup safe for 64-bit builds.
Browse files Browse the repository at this point in the history
Change: 143813950
  • Loading branch information
andrewharp authored and tensorflower-gardener committed Jan 6, 2017
1 parent 81560cd commit 06a5540
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
21 changes: 12 additions & 9 deletions tensorflow/examples/android/jni/object_tracking/jni_utils.h
Expand Up @@ -16,40 +16,43 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_

#include <stdint.h>

#include <android/log.h>

#include "tensorflow/examples/android/jni/object_tracking/utils.h"

// The JniIntField class is used to access Java fields from native code. This
// The JniLongField class is used to access Java fields from native code. This
// technique of hiding pointers to native objects in opaque Java fields is how
// the Android hardware libraries work. This reduces the amount of static
// native methods and makes it easier to manage the lifetime of native objects.
class JniIntField {
class JniLongField {
public:
JniIntField(const char* field_name) : field_name_(field_name), field_ID_(0) {}
JniLongField(const char* field_name)
: field_name_(field_name), field_ID_(0) {}

int get(JNIEnv* env, jobject thiz) {
int64_t get(JNIEnv* env, jobject thiz) {
if (field_ID_ == 0) {
jclass cls = env->GetObjectClass(thiz);
CHECK_ALWAYS(cls != 0, "Unable to find class");
field_ID_ = env->GetFieldID(cls, field_name_, "I");
field_ID_ = env->GetFieldID(cls, field_name_, "J");
CHECK_ALWAYS(field_ID_ != 0,
"Unable to find field %s. (Check proguard cfg)", field_name_);
}

return env->GetIntField(thiz, field_ID_);
return env->GetLongField(thiz, field_ID_);
}

void set(JNIEnv* env, jobject thiz, int value) {
void set(JNIEnv* env, jobject thiz, int64_t value) {
if (field_ID_ == 0) {
jclass cls = env->GetObjectClass(thiz);
CHECK_ALWAYS(cls != 0, "Unable to find class");
field_ID_ = env->GetFieldID(cls, field_name_, "I");
field_ID_ = env->GetFieldID(cls, field_name_, "J");
CHECK_ALWAYS(field_ID_ != 0,
"Unable to find field %s (Check proguard cfg)", field_name_);
}

env->SetIntField(thiz, field_ID_, value);
env->SetLongField(thiz, field_ID_, value);
}

private:
Expand Down
Expand Up @@ -37,7 +37,7 @@ namespace tf_tracking {
#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT

JniIntField object_tracker_field("nativeObjectTracker");
JniLongField object_tracker_field("nativeObjectTracker");

ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
ObjectTracker* const object_tracker =
Expand Down
Expand Up @@ -594,12 +594,10 @@ public synchronized TrackedObject trackObject(final RectF position, final byte[]
return new TrackedObject(position, lastTimestamp, frameData);
}

/*********************** NATIVE CODE *************************************/
/** ********************* NATIVE CODE ************************************ */

/**
* This will contain an opaque pointer to the native ObjectTracker
*/
private int nativeObjectTracker;
/** This will contain an opaque pointer to the native ObjectTracker */
private long nativeObjectTracker;

private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);

Expand Down

0 comments on commit 06a5540

Please sign in to comment.