Skip to content

Commit

Permalink
Keep weak references to eager resources in session (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
karllessard committed Mar 2, 2021
1 parent c4498eb commit f6024dd
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class EagerOperation extends AbstractOperation {
this.name = name;
this.opHandle = opNativeHandle;
this.outputHandles = outputNativeHandles;
session.attach(opNativeHandle);
session.attach(outputNativeHandles);
this.outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ final class EagerOperationBuilder implements OperationBuilder {
@Override
public EagerOperation build() {
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
EagerOperation operation =
new EagerOperation(session, opHandle, tensorHandles, type, name);
// Release our reference to the native op handle now that we transferred its
// ownership to the EagerOperation
session.detach(opHandle);
return operation;
return new EagerOperation(session, opHandle, tensorHandles, type, name);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.WeakPointerScope;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
Expand Down Expand Up @@ -310,13 +311,45 @@ TFE_Context nativeHandle() {
return nativeHandle;
}

/**
* Attach the list of native resources to this eager session scope.
*
* <p>When the eager session is closed (i.e. by calling {@link #close()} explicitly or
* implicitly via try-with-resources), all native resources attached to the session will be
* released as well, unless so other references are {@link Pointer#retainReference() retaining}
* them.</p>
*
* <p>Attached resources can still be garbage collected though if their associated {@link Pointer}
* is no longer reachable in Java, independently of their reference count. Therefore, it is
* assumed that these resources are not required by the native library once the Java client no
* longer needs them.</p>
*
* <p>Attaching a resource already attached to this session will have no effect.</p>
*
* @param resources resources to attach to the session
*/
void attach(Pointer... resources) {
checkSession();
for (Pointer r : resources) {
nativeResources.attach(r);
}
}

/**
* Detach a list of resources from this eager session scope.
*
* <p>Detached native resources will prevent them to be automatically released when the session is
* closed.</p>
*
* <p>Note though that this method will decrement the reference count of each resources being
* detached, which may automatically released them if that count reaches 0. Therefore,
* invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain
* valid after being detached might be required.</p>
*
* <p>Detaching a resource that is not attached to this session will have no effect.</p>
*
* @param resources resources to detach from the session
*/
void detach(Pointer... resources) {
checkSession();
for (Pointer r : resources) {
Expand All @@ -326,14 +359,12 @@ void detach(Pointer... resources) {

private static volatile EagerSession defaultSession = null;

private final PointerScope nativeResources;
private final WeakPointerScope nativeResources;
private TFE_Context nativeHandle;

private EagerSession(Options options) {
try (PointerScope scope = new PointerScope()) {
this.nativeResources = scope.extend();
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
}
this.nativeResources = new WeakPointerScope();
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
}

private void checkSession() {
Expand Down Expand Up @@ -363,7 +394,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co
TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy);
TFE_Context context = TFE_NewContext(opts, status);
status.throwExceptionIfNotOK();
return context;
return context.retainReference();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.tensorflow.internal;

import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;
import org.bytedeco.javacpp.Pointer;

/**
* A minimalist pointer scope only keeping weak references to its elements.
*
* <p>As opposed to {@link org.bytedeco.javacpp.PointerScope}, instances of this class will not
* prevent the garbage collector to free the memory of a pointer that is no longer reachable, even
* if it has been attached to the scope.</p>
*
* <p>When the scope is closed, all pointers that are still valid will be automatically deallocated
* while those already garbage-collected will be ignored.</p>
*/
public class WeakPointerScope implements AutoCloseable {

/**
* Attach a pointer to this scope.
*
* <p>Pointers attached to the scope will be automatically freed once the scope is closed, unless
* they have been already released by the garbage collector</p>
*
* <p>It this {@code pointer} was already attached to this scope, this method has no effect.</p>
*
* @param pointer pointer to attach
* @throws IllegalStateException if that scope has already been closed
*/
public void attach(Pointer pointer) {
checkScope();
if (pointers.add(pointer)) {
pointer.retainReference();
}
}

/**
* Detach a pointer from this scope.
*
* <p>Detaching a pointer from the scope will prevent its memory to be freed when closing the
* scope.</p>
*
* <p>If this {@code pointer} is not attached to this scope, this method has no effect.</p>
*
* @param pointer pointer to detach
* @throws IllegalStateException if that scope has already been closed
*/
public void detach(Pointer pointer) {
checkScope();
if (pointers.remove(pointer)) {
pointer.releaseReference();
}
}

@Override
public synchronized void close() {
checkScope();
pointers.forEach(Pointer::releaseReference);
pointers = null;
}

private Set<Pointer> pointers = Collections.newSetFromMap(new WeakHashMap<>());

private void checkScope() {
if (pointers == null) {
throw new IllegalStateException("Pointer scope has been closed");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ public class EagerOperationTest {
public void failToCreateIfSessionIsClosed() {
EagerSession session = EagerSession.create();
session.close();
try {
new EagerOperation(session, null, null, "Add", "add");
try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) {
EagerOperation op =
opBuilder(session, "Const", "OutputAttrs")
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build();
fail();
} catch (IllegalStateException e) {
// expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ public void cleanupResourceInBackground() {
sleep(50); // allow some time to the background thread for cleaning up resources

long before = Pointer.totalBytes();
s.detach(ref.retainReference());
ref = null;
System.gc();
sleep(50); // allow some time to the background thread for cleaning up resources
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package org.tensorflow.internal;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Test;
import org.tensorflow.EagerSession;

public class WeakPointerScopeTest {

@Test
public void resourcesAttachedAreFreedOnScopeClose() {
Pointer pointer = new IntPointer(10L);
assertEquals(0, pointer.referenceCount());

try (WeakPointerScope scope = new WeakPointerScope()) {
scope.attach(pointer);
assertEquals(1, pointer.referenceCount());
}
assertTrue(pointer.isNull());
}

@Test
public void resourcesDetachedAreNotFreedOnScopeCloseWhenRetained() {
Pointer pointer = new IntPointer(10L);

try (WeakPointerScope scope = new WeakPointerScope()) {
scope.attach(pointer);
scope.detach(pointer.retainReference());
}
assertFalse(pointer.isNull());
assertEquals(1, pointer.referenceCount());
pointer.deallocate();
}

@Test
public void resourcesDetachedAreFreedWhenNotRetained() {
Pointer pointer = new IntPointer(10L);

try (WeakPointerScope scope = new WeakPointerScope()) {
scope.attach(pointer);

scope.detach(pointer);
assertTrue(pointer.isNull());
}
}

@Test
public void attachingResourceMoreThanOnceHasNoEffect() {
Pointer pointer = new IntPointer(10L);

try (WeakPointerScope scope = new WeakPointerScope()) {
scope.attach(pointer);
scope.attach(pointer);
assertEquals(1, pointer.referenceCount());

Pointer pointerCopy = new Pointer(pointer);
assertEquals(1, pointerCopy.referenceCount());
scope.attach(pointerCopy);
assertEquals(1, pointerCopy.referenceCount());
}
assertTrue(pointer.isNull());
}

@Test
public void detachingUnattachedResourceHasNoEffect() {
Pointer pointer = new IntPointer(10L);
pointer.retainReference();
assertEquals(1, pointer.referenceCount());

try (WeakPointerScope scope = new WeakPointerScope()) {
scope.detach(pointer);
assertEquals(1, pointer.referenceCount());
}
assertFalse(pointer.isNull());
pointer.deallocate();
}

@Test
public void operationOnClosedScopeFails() {
Pointer pointer = new IntPointer(10L);
WeakPointerScope scope = new WeakPointerScope();
scope.close();

assertThrows(IllegalStateException.class, () -> scope.attach(pointer));
assertThrows(IllegalStateException.class, () -> scope.detach(pointer));
assertThrows(IllegalStateException.class, () -> scope.close());

pointer.deallocate();
}

@Test
public void attachingResourceDoesNotPreventItToBeGarbageCollected() throws InterruptedException {
try (WeakPointerScope scope = new WeakPointerScope()) {
Pointer pointer = new IntPointer(10L);
scope.attach(pointer);
System.gc();
Thread.sleep(50);

long before = Pointer.totalBytes();
pointer = null;
System.gc();
Thread.sleep(50);
long after = Pointer.totalBytes();

assertEquals(4 * 10L, before - after);
}
}
}

0 comments on commit f6024dd

Please sign in to comment.