From d5e51b0b6fc1534b11705723a865bfecf9119d11 Mon Sep 17 00:00:00 2001 From: klessard Date: Thu, 13 May 2021 10:50:20 -0400 Subject: [PATCH 1/2] Load TF library before computing TString size --- .../main/java/org/tensorflow/TensorFlow.java | 10 +++++-- .../buffer/ByteSequenceTensorBuffer.java | 5 ++++ .../java/org/tensorflow/WrongEnvTest.java | 2 +- .../tensorflow/op/core/BooleanMaskTest.java | 29 ++++++++++--------- .../op/core/BooleanMaskUpdateTest.java | 3 +- .../org/tensorflow/op/core/IndexingTest.java | 2 +- 6 files changed, 32 insertions(+), 19 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 338101c962b..3dd93aa382e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -32,6 +32,12 @@ /** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { + + /** Make sure all TensorFlow native libraries have been loaded properly */ + public static void init() { + // Do nothing, we'll let the class static initializer load the native library if needed + } + /** Returns the version of the underlying TensorFlow runtime. */ public static String version() { return TF_Version().getString(); @@ -106,7 +112,7 @@ private static OpList libraryOpList(TF_Library handle) { private TensorFlow() {} /** Load the TensorFlow runtime C library. */ - static void init() { + private static void initTensorFlow() { try { NativeLibrary.load(); } catch (Exception e) { @@ -123,6 +129,6 @@ static void init() { } static { - init(); + initTensorFlow(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java index acaeaedbc11..a5d94ae4f28 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java @@ -28,6 +28,7 @@ import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.TensorFlow; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.internal.c_api.TF_TString; import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; @@ -132,4 +133,8 @@ void writeNext(byte[] bytes) { } private final TF_TString data; + + static { + TensorFlow.init(); // make sure TF library is loaded before working with `TF_TString` native objects + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java index b2fbc1e794a..18bdeb40e83 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java @@ -20,7 +20,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java index a4d9293ccf8..7c5210c0f2d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -18,7 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; @@ -29,6 +29,7 @@ import org.tensorflow.types.TInt32; public class BooleanMaskTest { + @Test public void testBooleanMask(){ try (Graph g = new Graph(); @@ -43,24 +44,24 @@ public void testBooleanMask(){ Operand output1 = BooleanMask.create(scope, input, mask); Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { + try (TInt32 result = (TInt32) sess.runner().fetch(output1).run().get(0)) { // expected shape from Python tensorflow assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); + assertEquals(0, result.getInt(0)); + assertEquals(1, result.getInt(1)); + assertEquals(4, result.getInt(2)); + assertEquals(5, result.getInt(3)); + assertEquals(6, result.getInt(4)); } - try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { + try (TInt32 result = (TInt32) sess.runner().fetch(output2).run().get(0)) { // expected shape from Python tensorflow - assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); + assertEquals(Shape.of(1, 5), result.shape()); + assertEquals(0, result.getInt(0, 0)); + assertEquals(1, result.getInt(0, 1)); + assertEquals(4, result.getInt(0, 2)); + assertEquals(5, result.getInt(0, 3)); + assertEquals(6, result.getInt(0, 4)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index ab852bbffb2..c2b514bfdb6 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -19,7 +19,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.List; -import org.junit.Test; + +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index 6e86573b7cf..9a66d2445d2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -17,7 +17,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.ndarray.Shape; From e18e55b2007617c08100f8557df943b00578a97f Mon Sep 17 00:00:00 2001 From: klessard Date: Fri, 14 May 2021 11:02:08 -0400 Subject: [PATCH 2/2] Use Class.forName() to statically load TF library once --- .../java/org/tensorflow/EagerSession.java | 7 ++++++- .../src/main/java/org/tensorflow/Graph.java | 7 ++++++- .../main/java/org/tensorflow/RawTensor.java | 7 ++++++- .../java/org/tensorflow/SavedModelBundle.java | 7 ++++++- .../src/main/java/org/tensorflow/Server.java | 7 ++++++- .../main/java/org/tensorflow/TensorFlow.java | 11 +--------- .../buffer/ByteSequenceTensorBuffer.java | 7 ++++++- .../org/tensorflow/types/TStringTest.java | 20 ++++++++++--------- 8 files changed, 48 insertions(+), 25 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 8e7465388a8..dad842f7038 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -404,6 +404,11 @@ private static void delete(TFE_Context handle) { } static { - TensorFlow.init(); + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index ff805c73b53..7f659b262a6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -1070,6 +1070,11 @@ private static SaverDef addVariableSaver(Graph graph) { } static { - TensorFlow.init(); + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..2a4a21face3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -228,6 +228,11 @@ private static long[] shape(TF_Tensor handle) { private ByteDataBuffer buffer = null; static { - TensorFlow.init(); + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 0974cc94a24..6992e5eee37 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -435,6 +435,11 @@ private static void validateTags(String[] tags) { } static { - TensorFlow.init(); + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java index e3b685889e1..2488a93c929 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Server.java @@ -178,6 +178,11 @@ private static void delete(TF_Server nativeHandle) { private int numJoining; static { - TensorFlow.init(); + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 3dd93aa382e..de481d256a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -33,11 +33,6 @@ /** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { - /** Make sure all TensorFlow native libraries have been loaded properly */ - public static void init() { - // Do nothing, we'll let the class static initializer load the native library if needed - } - /** Returns the version of the underlying TensorFlow runtime. */ public static String version() { return TF_Version().getString(); @@ -112,7 +107,7 @@ private static OpList libraryOpList(TF_Library handle) { private TensorFlow() {} /** Load the TensorFlow runtime C library. */ - private static void initTensorFlow() { + static { try { NativeLibrary.load(); } catch (Exception e) { @@ -127,8 +122,4 @@ private static void initTensorFlow() { throw e; } } - - static { - initTensorFlow(); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java index a5d94ae4f28..48ee4f72bee 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java @@ -135,6 +135,11 @@ void writeNext(byte[] bytes) { private final TF_TString data; static { - TensorFlow.init(); // make sure TF library is loaded before working with `TF_TString` native objects + try { + // Ensure that TensorFlow native library and classes are ready to be used + Class.forName("org.tensorflow.TensorFlow"); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java index 7efeb93f0d8..c8182ec8d57 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java @@ -23,7 +23,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; + import org.bytedeco.javacpp.Pointer; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; @@ -39,16 +41,15 @@ public void createScalar() { assertEquals("Pretty vacant", tensor.getObject()); } - @Test - public void createrScalarLongerThan127() { - TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); - assertNotNull(tensor); - assertEquals(Shape.scalar(), tensor.shape()); - assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); - } - + @Test + public void createrScalarLongerThan127() { + TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); + assertNotNull(tensor); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); + } - @Test + @Test public void createVector() { TString tensor = TString.vectorOf("Pretty", "vacant"); assertNotNull(tensor); @@ -106,6 +107,7 @@ public void initializingTensorWithRawBytes() { } @Test + @Disabled // FIXME This test does not deterministically succeed, so skip it by default public void testNoLeaks() throws Exception { // warm up and try to get all JIT compilation done to stabilize memory usage... for (int i = 0; i < 1000; i++) {