Skip to content

Commit

Permalink
Fix closing TrinoResultSet background thread
Browse files Browse the repository at this point in the history
ResultSet cannot be properly closed because the inner Thread cannot
be interrupted and stop data row iteration. That will lead to thread
and memory leaks on the client side. This patch uses FutureTask,
which is created by ThreadPoolExecutor, instead of CompletableFuture
to make sure `Thread.interrupt()` can be invoked as expected. And
for the case that interruption is not properly handled by the
underlying StatementClient, a status check is added to the loop
condition so that loop can terminate and thread can be released.
  • Loading branch information
xiacongling authored and findepi committed Sep 19, 2022
1 parent 6acfd82 commit 601a3a7
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 8 deletions.
53 changes: 45 additions & 8 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand All @@ -29,9 +30,9 @@
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;
import java.util.stream.Stream;
Expand Down Expand Up @@ -144,27 +145,39 @@ private static <T> Iterator<T> flatten(Iterator<Iterable<T>> iterator, long maxR
return stream.iterator();
}

private static class AsyncIterator<T>
@VisibleForTesting
static class AsyncIterator<T>
extends AbstractIterator<T>
{
private static final int MAX_QUEUED_ROWS = 50_000;
private static final ExecutorService executorService = newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build());

private final StatementClient client;
private final BlockingQueue<T> rowQueue = new ArrayBlockingQueue<>(MAX_QUEUED_ROWS);
private final BlockingQueue<T> rowQueue;
// Semaphore to indicate that some data is ready.
// Each permit represents a row of data (or that the underlying iterator is exhausted).
private final Semaphore semaphore = new Semaphore(0);
private final CompletableFuture<Void> future;
private final Future<?> future;
private volatile boolean cancelled;
private volatile boolean finished;

public AsyncIterator(Iterator<T> dataIterator, StatementClient client)
{
this(dataIterator, client, Optional.empty());
}

@VisibleForTesting
AsyncIterator(Iterator<T> dataIterator, StatementClient client, Optional<BlockingQueue<T>> queue)
{
requireNonNull(dataIterator, "dataIterator is null");
this.client = client;
this.future = CompletableFuture.runAsync(() -> {
this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS));
this.cancelled = false;
this.finished = false;
this.future = executorService.submit(() -> {
try {
while (dataIterator.hasNext()) {
while (!cancelled && dataIterator.hasNext()) {
rowQueue.put(dataIterator.next());
semaphore.release();
}
Expand All @@ -174,22 +187,46 @@ public AsyncIterator(Iterator<T> dataIterator, StatementClient client)
}
finally {
semaphore.release();
finished = true;
}
}, executorService);
});
}

public void cancel()
{
cancelled = true;
future.cancel(true);
cleanup();
}

public void interrupt(InterruptedException e)
{
client.close();
cleanup();
Thread.currentThread().interrupt();
throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e));
}

private void cleanup()
{
// When thread interruption is mis-handled by underlying implementation of `client`, the thread which
// is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish
// its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks.
client.close();
rowQueue.clear();
}

@VisibleForTesting
Future<?> getFuture()
{
return future;
}

@VisibleForTesting
boolean isBackgroundThreadFinished()
{
return finished;
}

@Override
protected T computeNext()
{
Expand Down
Expand Up @@ -24,6 +24,9 @@

import static java.lang.String.format;

/**
* An integration test for JDBC client interacting with Trino server.
*/
public class TestJdbcResultSet
extends BaseTestJdbcResultSet
{
Expand Down
272 changes: 272 additions & 0 deletions client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java
@@ -0,0 +1,272 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import com.google.common.collect.ImmutableList;
import io.trino.client.ClientSelectedRole;
import io.trino.client.QueryData;
import io.trino.client.QueryStatusInfo;
import io.trino.client.StatementClient;
import io.trino.client.StatementStats;
import org.testng.annotations.Test;

import java.time.ZoneId;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.testng.Assert.assertTrue;

/**
* A unit test for {@link TrinoResultSet}.
*
* @see TestJdbcResultSet for an integration test.
*/
public class TestTrinoResultSet
{
@Test(timeOut = 10000)
public void testIteratorCancelWhenQueueNotFull()
throws Exception
{
AtomicReference<Thread> thread = new AtomicReference<>();
CountDownLatch interruptedButSwallowedLatch = new CountDownLatch(1);
MockAsyncIterator<Iterable<List<Object>>> iterator = new MockAsyncIterator<>(
new Iterator<Iterable<List<Object>>>()
{
@Override
public boolean hasNext()
{
return true;
}

@Override
public Iterable<List<Object>> next()
{
thread.compareAndSet(null, Thread.currentThread());
try {
TimeUnit.MILLISECONDS.sleep(1000);
}
catch (InterruptedException e) {
interruptedButSwallowedLatch.countDown();
}
return ImmutableList.of((ImmutableList.of(new Object())));
}
},
new ArrayBlockingQueue<>(100));

while (thread.get() == null || thread.get().getState() != Thread.State.TIMED_WAITING) {
// wait for thread being waiting
}
iterator.cancel();
while (!iterator.getFuture().isDone() || !iterator.isBackgroundThreadFinished()) {
TimeUnit.MILLISECONDS.sleep(10);
}
boolean interruptedButSwallowed = interruptedButSwallowedLatch.await(5000, TimeUnit.MILLISECONDS);
assertTrue(interruptedButSwallowed);
}

@Test(timeOut = 10000)
public void testIteratorCancelWhenQueueIsFull()
throws Exception
{
BlockingQueue<Iterable<List<Object>>> queue = new ArrayBlockingQueue<>(1);
queue.put(ImmutableList.of());
// queue is full at the beginning
AtomicReference<Thread> thread = new AtomicReference<>();
MockAsyncIterator<Iterable<List<Object>>> iterator = new MockAsyncIterator<>(
new Iterator<Iterable<List<Object>>>()
{
@Override
public boolean hasNext()
{
return true;
}

@Override
public Iterable<List<Object>> next()
{
thread.compareAndSet(null, Thread.currentThread());
return ImmutableList.of((ImmutableList.of(new Object())));
}
},
queue);

while (thread.get() == null || thread.get().getState() != Thread.State.WAITING) {
// wait for thread being waiting (for queue being not full)
TimeUnit.MILLISECONDS.sleep(10);
}
iterator.cancel();
while (!iterator.isBackgroundThreadFinished()) {
TimeUnit.MILLISECONDS.sleep(10);
}
}

private static class MockAsyncIterator<T>
extends TrinoResultSet.AsyncIterator<T>
{
public MockAsyncIterator(Iterator<T> dataIterator, BlockingQueue<T> queue)
{
super(
dataIterator,
new StatementClient()
{
@Override
public String getQuery()
{
throw new UnsupportedOperationException();
}

@Override
public ZoneId getTimeZone()
{
throw new UnsupportedOperationException();
}

@Override
public boolean isRunning()
{
throw new UnsupportedOperationException();
}

@Override
public boolean isClientAborted()
{
throw new UnsupportedOperationException();
}

@Override
public boolean isClientError()
{
throw new UnsupportedOperationException();
}

@Override
public boolean isFinished()
{
throw new UnsupportedOperationException();
}

@Override
public StatementStats getStats()
{
throw new UnsupportedOperationException();
}

@Override
public QueryStatusInfo currentStatusInfo()
{
throw new UnsupportedOperationException();
}

@Override
public QueryData currentData()
{
throw new UnsupportedOperationException();
}

@Override
public QueryStatusInfo finalStatusInfo()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSetCatalog()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSetSchema()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<String> getSetPath()
{
throw new UnsupportedOperationException();
}

@Override
public Map<String, String> getSetSessionProperties()
{
throw new UnsupportedOperationException();
}

@Override
public Set<String> getResetSessionProperties()
{
throw new UnsupportedOperationException();
}

@Override
public Map<String, ClientSelectedRole> getSetRoles()
{
throw new UnsupportedOperationException();
}

@Override
public Map<String, String> getAddedPreparedStatements()
{
throw new UnsupportedOperationException();
}

@Override
public Set<String> getDeallocatedPreparedStatements()
{
throw new UnsupportedOperationException();
}

@Override
public String getStartedTransactionId()
{
throw new UnsupportedOperationException();
}

@Override
public boolean isClearTransactionId()
{
throw new UnsupportedOperationException();
}

@Override
public boolean advance()
{
throw new UnsupportedOperationException();
}

@Override
public void cancelLeafStage()
{
throw new UnsupportedOperationException();
}

@Override
public void close()
{
// do nothing
}
},
Optional.of(queue));
}
}
}

0 comments on commit 601a3a7

Please sign in to comment.