Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ private static void delete(TFE_Context handle) {
}

static {
TensorFlow.init();
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,11 @@ private static SaverDef addVariableSaver(Graph graph) {
}

static {
TensorFlow.init();
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ private static long[] shape(TF_Tensor handle) {
private ByteDataBuffer buffer = null;

static {
TensorFlow.init();
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ private static void validateTags(String[] tags) {
}

static {
TensorFlow.init();
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ private static void delete(TF_Server nativeHandle) {
private int numJoining;

static {
TensorFlow.init();
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

/** Static utility methods describing the TensorFlow runtime. */
public final class TensorFlow {

/** Returns the version of the underlying TensorFlow runtime. */
public static String version() {
return TF_Version().getString();
Expand Down Expand Up @@ -106,7 +107,7 @@ private static OpList libraryOpList(TF_Library handle) {
private TensorFlow() {}

/** Load the TensorFlow runtime C library. */
static void init() {
static {
try {
NativeLibrary.load();
} catch (Exception e) {
Expand All @@ -121,8 +122,4 @@ static void init() {
throw e;
}
}

static {
init();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.TensorFlow;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.internal.c_api.TF_TString;
import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer;
Expand Down Expand Up @@ -132,4 +133,13 @@ void writeNext(byte[] bytes) {
}

private final TF_TString data;

static {
try {
// Ensure that TensorFlow native library and classes are ready to be used
Class.forName("org.tensorflow.TensorFlow");
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt32;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
Expand All @@ -29,6 +29,7 @@
import org.tensorflow.types.TInt32;

public class BooleanMaskTest {

@Test
public void testBooleanMask(){
try (Graph g = new Graph();
Expand All @@ -43,24 +44,24 @@ public void testBooleanMask(){
Operand<TInt32> output1 = BooleanMask.create(scope, input, mask);
Operand<TInt32> output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1));

try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) {
try (TInt32 result = (TInt32) sess.runner().fetch(output1).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(5), result.shape());
assertEquals(0, result.getFloat(0));
assertEquals(1, result.getFloat(1));
assertEquals(4, result.getFloat(2));
assertEquals(5, result.getFloat(3));
assertEquals(6, result.getFloat(4));
assertEquals(0, result.getInt(0));
assertEquals(1, result.getInt(1));
assertEquals(4, result.getInt(2));
assertEquals(5, result.getInt(3));
assertEquals(6, result.getInt(4));
}

try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) {
try (TInt32 result = (TInt32) sess.runner().fetch(output2).run().get(0)) {
// expected shape from Python tensorflow
assertEquals(Shape.of(5), result.shape());
assertEquals(0, result.getFloat(0));
assertEquals(1, result.getFloat(1));
assertEquals(4, result.getFloat(2));
assertEquals(5, result.getFloat(3));
assertEquals(6, result.getFloat(4));
assertEquals(Shape.of(1, 5), result.shape());
assertEquals(0, result.getInt(0, 0));
assertEquals(1, result.getInt(0, 1));
assertEquals(4, result.getInt(0, 2));
assertEquals(5, result.getInt(0, 3));
assertEquals(6, result.getInt(0, 4));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;
import org.junit.Test;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.charset.StandardCharsets;

import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
Expand All @@ -39,16 +41,15 @@ public void createScalar() {
assertEquals("Pretty vacant", tensor.getObject());
}

@Test
public void createrScalarLongerThan127() {
TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !");
assertNotNull(tensor);
assertEquals(Shape.scalar(), tensor.shape());
assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject());
}

@Test
public void createrScalarLongerThan127() {
TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !");
assertNotNull(tensor);
assertEquals(Shape.scalar(), tensor.shape());
assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject());
}

@Test
@Test
public void createVector() {
TString tensor = TString.vectorOf("Pretty", "vacant");
assertNotNull(tensor);
Expand Down Expand Up @@ -106,6 +107,7 @@ public void initializingTensorWithRawBytes() {
}

@Test
@Disabled // FIXME This test does not deterministically succeed, so skip it by default
public void testNoLeaks() throws Exception {
// warm up and try to get all JIT compilation done to stabilize memory usage...
for (int i = 0; i < 1000; i++) {
Expand Down