From ea7e61425903cdfffa7a56bee6c723e5c017d0b1 Mon Sep 17 00:00:00 2001 From: klessard Date: Sun, 28 Feb 2021 22:43:53 -0500 Subject: [PATCH 1/2] Keep weak references to eager resources in session --- .../java/org/tensorflow/EagerOperation.java | 2 - .../org/tensorflow/EagerOperationBuilder.java | 7 +-- .../java/org/tensorflow/EagerSession.java | 11 ++-- .../tensorflow/internal/WeakPointerScope.java | 58 +++++++++++++++++++ .../org/tensorflow/EagerOperationTest.java | 8 ++- .../java/org/tensorflow/EagerSessionTest.java | 1 - 6 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index a5c2df84026..9f87fd8b95e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -53,8 +53,6 @@ class EagerOperation extends AbstractOperation { this.name = name; this.opHandle = opNativeHandle; this.outputHandles = outputNativeHandles; - session.attach(opNativeHandle); - session.attach(outputNativeHandles); this.outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index a865300bc5a..f1dd6216a79 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -65,12 +65,7 @@ final class EagerOperationBuilder implements OperationBuilder { @Override public EagerOperation build() { TFE_TensorHandle[] tensorHandles = execute(opHandle, session); - EagerOperation operation = - new EagerOperation(session, opHandle, tensorHandles, type, name); - // Release our reference to the native op handle now that we transferred its - // ownership to the EagerOperation - session.detach(opHandle); - return operation; + return new EagerOperation(session, opHandle, tensorHandles, type, name); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 75bc12b5a6c..c58a994322d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -24,6 +24,7 @@ import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.internal.WeakPointerScope; import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; @@ -326,14 +327,12 @@ void detach(Pointer... resources) { private static volatile EagerSession defaultSession = null; - private final PointerScope nativeResources; + private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; private EagerSession(Options options) { - try (PointerScope scope = new PointerScope()) { - this.nativeResources = scope.extend(); - this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); - } + this.nativeResources = new WeakPointerScope(); + this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); } private void checkSession() { @@ -363,7 +362,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy); TFE_Context context = TFE_NewContext(opts, status); status.throwExceptionIfNotOK(); - return context; + return context.retainReference(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java new file mode 100644 index 00000000000..4444e14d042 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java @@ -0,0 +1,58 @@ +package org.tensorflow.internal; + +import java.lang.ref.WeakReference; +import java.util.Collections; +import java.util.LinkedList; +import java.util.Set; +import java.util.WeakHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.bytedeco.javacpp.Pointer; + +/** + * A minimalist pointer scope only keeping weak references to its elements. + * + *

As opposed to {@link org.bytedeco.javacpp.PointerScope}, instances of this class will not + * prevent the garbage collector to free the memory of a pointer that is no longer reachable, even + * if it has been attached to the scope.

+ * + *

When the scope is closed, all pointers that are still valid will be automatically deallocated + * while those already garbage-collected will be ignored.

+ */ +public class WeakPointerScope implements AutoCloseable { + + /** + * Attach a pointer to this scope. + * + *

Pointers attached to the scope will be automatically freed once the scope is closed, unless + * they have been already released by the garbage collector

+ * + * @param pointer pointer to attach + */ + public void attach(Pointer pointer) { + pointers.add(pointer.retainReference()); + } + + /** + * Detach a pointer from this scope. + * + *

Detaching a pointer from the scope will prevent its memory to be freed when closing the + * scope.

+ * + *

If this {@code pointer} is not attached to this scope, this method has no effect.

+ * + * @param pointer pointer to detach + */ + public void detach(Pointer pointer) { + if (pointers.remove(pointer)) { + pointer.releaseReference(); + } + } + + @Override + public synchronized void close() { + pointers.forEach(Pointer::releaseReference); + pointers.clear(); + } + + private final Set pointers = Collections.newSetFromMap(new WeakHashMap<>()); +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 2920fbdf59f..38714b86599 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -35,8 +35,12 @@ public class EagerOperationTest { public void failToCreateIfSessionIsClosed() { EagerSession session = EagerSession.create(); session.close(); - try { - new EagerOperation(session, null, null, "Add", "add"); + try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + EagerOperation op = + opBuilder(session, "Const", "OutputAttrs") + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build(); fail(); } catch (IllegalStateException e) { // expected diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java index 7ac54213a0b..77325d50dcc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java @@ -59,7 +59,6 @@ public void cleanupResourceInBackground() { sleep(50); // allow some time to the background thread for cleaning up resources long before = Pointer.totalBytes(); - s.detach(ref.retainReference()); ref = null; System.gc(); sleep(50); // allow some time to the background thread for cleaning up resources From 76e24d02b959ca1fb8e7dfeba8631f4bb44677b4 Mon Sep 17 00:00:00 2001 From: klessard Date: Mon, 1 Mar 2021 21:39:18 -0500 Subject: [PATCH 2/2] Add documentation and unit test --- .../java/org/tensorflow/EagerSession.java | 32 +++++ .../tensorflow/internal/WeakPointerScope.java | 24 +++- .../internal/WeakPointerScopeTest.java | 114 ++++++++++++++++++ 3 files changed, 164 insertions(+), 6 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index c58a994322d..8e7465388a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -311,6 +311,23 @@ TFE_Context nativeHandle() { return nativeHandle; } + /** + * Attach the list of native resources to this eager session scope. + * + *

When the eager session is closed (i.e. by calling {@link #close()} explicitly or + * implicitly via try-with-resources), all native resources attached to the session will be + * released as well, unless so other references are {@link Pointer#retainReference() retaining} + * them.

+ * + *

Attached resources can still be garbage collected though if their associated {@link Pointer} + * is no longer reachable in Java, independently of their reference count. Therefore, it is + * assumed that these resources are not required by the native library once the Java client no + * longer needs them.

+ * + *

Attaching a resource already attached to this session will have no effect.

+ * + * @param resources resources to attach to the session + */ void attach(Pointer... resources) { checkSession(); for (Pointer r : resources) { @@ -318,6 +335,21 @@ void attach(Pointer... resources) { } } + /** + * Detach a list of resources from this eager session scope. + * + *

Detached native resources will prevent them to be automatically released when the session is + * closed.

+ * + *

Note though that this method will decrement the reference count of each resources being + * detached, which may automatically released them if that count reaches 0. Therefore, + * invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain + * valid after being detached might be required.

+ * + *

Detaching a resource that is not attached to this session will have no effect.

+ * + * @param resources resources to detach from the session + */ void detach(Pointer... resources) { checkSession(); for (Pointer r : resources) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java index 4444e14d042..f12e97c2702 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java @@ -1,11 +1,8 @@ package org.tensorflow.internal; -import java.lang.ref.WeakReference; import java.util.Collections; -import java.util.LinkedList; import java.util.Set; import java.util.WeakHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; import org.bytedeco.javacpp.Pointer; /** @@ -26,10 +23,16 @@ public class WeakPointerScope implements AutoCloseable { *

Pointers attached to the scope will be automatically freed once the scope is closed, unless * they have been already released by the garbage collector

* + *

It this {@code pointer} was already attached to this scope, this method has no effect.

+ * * @param pointer pointer to attach + * @throws IllegalStateException if that scope has already been closed */ public void attach(Pointer pointer) { - pointers.add(pointer.retainReference()); + checkScope(); + if (pointers.add(pointer)) { + pointer.retainReference(); + } } /** @@ -41,8 +44,10 @@ public void attach(Pointer pointer) { *

If this {@code pointer} is not attached to this scope, this method has no effect.

* * @param pointer pointer to detach + * @throws IllegalStateException if that scope has already been closed */ public void detach(Pointer pointer) { + checkScope(); if (pointers.remove(pointer)) { pointer.releaseReference(); } @@ -50,9 +55,16 @@ public void detach(Pointer pointer) { @Override public synchronized void close() { + checkScope(); pointers.forEach(Pointer::releaseReference); - pointers.clear(); + pointers = null; } - private final Set pointers = Collections.newSetFromMap(new WeakHashMap<>()); + private Set pointers = Collections.newSetFromMap(new WeakHashMap<>()); + + private void checkScope() { + if (pointers == null) { + throw new IllegalStateException("Pointer scope has been closed"); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java new file mode 100644 index 00000000000..815a1200c89 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java @@ -0,0 +1,114 @@ +package org.tensorflow.internal; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.bytedeco.javacpp.IntPointer; +import org.bytedeco.javacpp.Pointer; +import org.junit.jupiter.api.Test; +import org.tensorflow.EagerSession; + +public class WeakPointerScopeTest { + + @Test + public void resourcesAttachedAreFreedOnScopeClose() { + Pointer pointer = new IntPointer(10L); + assertEquals(0, pointer.referenceCount()); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + assertEquals(1, pointer.referenceCount()); + } + assertTrue(pointer.isNull()); + } + + @Test + public void resourcesDetachedAreNotFreedOnScopeCloseWhenRetained() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + scope.detach(pointer.retainReference()); + } + assertFalse(pointer.isNull()); + assertEquals(1, pointer.referenceCount()); + pointer.deallocate(); + } + + @Test + public void resourcesDetachedAreFreedWhenNotRetained() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + + scope.detach(pointer); + assertTrue(pointer.isNull()); + } + } + + @Test + public void attachingResourceMoreThanOnceHasNoEffect() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + scope.attach(pointer); + assertEquals(1, pointer.referenceCount()); + + Pointer pointerCopy = new Pointer(pointer); + assertEquals(1, pointerCopy.referenceCount()); + scope.attach(pointerCopy); + assertEquals(1, pointerCopy.referenceCount()); + } + assertTrue(pointer.isNull()); + } + + @Test + public void detachingUnattachedResourceHasNoEffect() { + Pointer pointer = new IntPointer(10L); + pointer.retainReference(); + assertEquals(1, pointer.referenceCount()); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.detach(pointer); + assertEquals(1, pointer.referenceCount()); + } + assertFalse(pointer.isNull()); + pointer.deallocate(); + } + + @Test + public void operationOnClosedScopeFails() { + Pointer pointer = new IntPointer(10L); + WeakPointerScope scope = new WeakPointerScope(); + scope.close(); + + assertThrows(IllegalStateException.class, () -> scope.attach(pointer)); + assertThrows(IllegalStateException.class, () -> scope.detach(pointer)); + assertThrows(IllegalStateException.class, () -> scope.close()); + + pointer.deallocate(); + } + + @Test + public void attachingResourceDoesNotPreventItToBeGarbageCollected() throws InterruptedException { + try (WeakPointerScope scope = new WeakPointerScope()) { + Pointer pointer = new IntPointer(10L); + scope.attach(pointer); + System.gc(); + Thread.sleep(50); + + long before = Pointer.totalBytes(); + pointer = null; + System.gc(); + Thread.sleep(50); + long after = Pointer.totalBytes(); + + assertEquals(4 * 10L, before - after); + } + } +}