Skip to content

Commit

Permalink
Add limit pushdown to memory connector
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Mar 12, 2019
1 parent 5fcc9e0 commit e63051b
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 29 deletions.
Expand Up @@ -32,6 +32,7 @@
import io.prestosql.spi.connector.ConnectorTableMetadata;
import io.prestosql.spi.connector.ConnectorTableProperties;
import io.prestosql.spi.connector.ConnectorViewDefinition;
import io.prestosql.spi.connector.LimitApplicationResult;
import io.prestosql.spi.connector.SchemaNotFoundException;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.SchemaTablePrefix;
Expand All @@ -47,6 +48,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

Expand Down Expand Up @@ -348,4 +350,16 @@ public List<MemoryDataFragment> getDataFragments(long tableId)
{
return ImmutableList.copyOf(tables.get(tableId).getDataFragments().values());
}

@Override
public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(ConnectorTableHandle handle, long limit)
{
MemoryTableHandle table = (MemoryTableHandle) handle;

if (!table.getLimit().isPresent() || limit < table.getLimit().getAsLong()) {
table = new MemoryTableHandle(table.getId(), OptionalLong.of(limit));
}

return Optional.of(new LimitApplicationResult<>(table, true));
}
}
Expand Up @@ -61,7 +61,8 @@ public ConnectorPageSource createPageSource(
partNumber,
totalParts,
columnIndexes,
expectedRows);
expectedRows,
memorySplit.getLimit());

return new FixedPageSource(pages);
}
Expand Down
Expand Up @@ -28,6 +28,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.Set;

import static io.prestosql.plugin.memory.MemoryErrorCode.MEMORY_LIMIT_EXCEEDED;
Expand Down Expand Up @@ -80,7 +81,8 @@ public synchronized List<Page> getPages(
int partNumber,
int totalParts,
List<Integer> columnIndexes,
long expectedRows)
long expectedRows,
OptionalLong limit)
{
if (!contains(tableId)) {
throw new PrestoException(MISSING_DATA, "Failed to find table on a worker.");
Expand All @@ -93,8 +95,17 @@ public synchronized List<Page> getPages(

ImmutableList.Builder<Page> partitionedPages = ImmutableList.builder();

for (int i = partNumber; i < tableData.getPages().size(); i += totalParts) {
partitionedPages.add(getColumns(tableData.getPages().get(i), columnIndexes));
boolean done = false;
long totalRows = 0;
for (int i = partNumber; i < tableData.getPages().size() && !done; i += totalParts) {
Page page = tableData.getPages().get(i);

totalRows += page.getPositionCount();
if (limit.isPresent() && totalRows > limit.getAsLong()) {
page = page.getRegion(0, (int) (page.getPositionCount() - (totalRows - limit.getAsLong())));
done = true;
}
partitionedPages.add(getColumns(page, columnIndexes));
}

return partitionedPages.build();
Expand Down
Expand Up @@ -21,6 +21,7 @@

import java.util.List;
import java.util.Objects;
import java.util.OptionalLong;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -34,14 +35,16 @@ public class MemorySplit
private final int partNumber; // part of the pages on one worker that this splits is responsible
private final HostAddress address;
private final long expectedRows;
private final OptionalLong limit;

@JsonCreator
public MemorySplit(
@JsonProperty("table") long table,
@JsonProperty("partNumber") int partNumber,
@JsonProperty("totalPartsPerWorker") int totalPartsPerWorker,
@JsonProperty("address") HostAddress address,
@JsonProperty("expectedRows") long expectedRows)
@JsonProperty("expectedRows") long expectedRows,
@JsonProperty("limit") OptionalLong limit)
{
checkState(partNumber >= 0, "partNumber must be >= 0");
checkState(totalPartsPerWorker >= 1, "totalPartsPerWorker must be >= 1");
Expand All @@ -52,6 +55,7 @@ public MemorySplit(
this.totalPartsPerWorker = totalPartsPerWorker;
this.address = requireNonNull(address, "address is null");
this.expectedRows = expectedRows;
this.limit = limit;
}

@JsonProperty
Expand Down Expand Up @@ -102,6 +106,12 @@ public long getExpectedRows()
return expectedRows;
}

@JsonProperty
public OptionalLong getLimit()
{
return limit;
}

@Override
public boolean equals(Object obj)
{
Expand Down
Expand Up @@ -25,6 +25,7 @@
import javax.inject.Inject;

import java.util.List;
import java.util.OptionalLong;

public final class MemorySplitManager
implements ConnectorSplitManager
Expand All @@ -46,16 +47,22 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand

List<MemoryDataFragment> dataFragments = metadata.getDataFragments(table.getId());

int totalRows = 0;

ImmutableList.Builder<ConnectorSplit> splits = ImmutableList.builder();

for (MemoryDataFragment dataFragment : dataFragments) {
long rows = dataFragment.getRows();
totalRows += rows;

if (table.getLimit().isPresent() && totalRows > table.getLimit().getAsLong()) {
rows -= totalRows - table.getLimit().getAsLong();
splits.add(new MemorySplit(table.getId(), 0, 1, dataFragment.getHostAddress(), rows, OptionalLong.of(rows)));
break;
}

for (int i = 0; i < splitsPerNode; i++) {
splits.add(
new MemorySplit(
table.getId(),
i,
splitsPerNode,
dataFragment.getHostAddress(),
dataFragment.getRows()));
splits.add(new MemorySplit(table.getId(), i, splitsPerNode, dataFragment.getHostAddress(), rows, OptionalLong.empty()));
}
}
return new FixedSplitSource(splits.build());
Expand Down
Expand Up @@ -15,20 +15,31 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.primitives.Longs;
import io.prestosql.spi.connector.ConnectorTableHandle;

import java.util.Objects;
import java.util.OptionalLong;

import static java.util.Objects.requireNonNull;

public final class MemoryTableHandle
implements ConnectorTableHandle
{
private final long id;
private final OptionalLong limit;

public MemoryTableHandle(long id)
{
this(id, OptionalLong.empty());
}

@JsonCreator
public MemoryTableHandle(@JsonProperty("id") long id)
public MemoryTableHandle(
@JsonProperty("id") long id,
@JsonProperty("limit") OptionalLong limit)
{
this.id = id;
this.limit = requireNonNull(limit, "limit is null");
}

@JsonProperty
Expand All @@ -37,28 +48,38 @@ public long getId()
return id;
}

@Override
public int hashCode()
@JsonProperty
public OptionalLong getLimit()
{
return Longs.hashCode(id);
return limit;
}

@Override
public boolean equals(Object obj)
public boolean equals(Object o)
{
if (this == obj) {
if (this == o) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
if (o == null || getClass() != o.getClass()) {
return false;
}
MemoryTableHandle other = (MemoryTableHandle) obj;
return Objects.equals(this.getId(), other.getId());
MemoryTableHandle that = (MemoryTableHandle) o;
return id == that.id &&
limit.equals(that.limit);
}

@Override
public int hashCode()
{
return Objects.hash(id, limit);
}

@Override
public String toString()
{
return Long.toString(id);
StringBuilder builder = new StringBuilder();
builder.append(id);
limit.ifPresent(value -> builder.append("(limit:" + value + ")"));
return builder.toString();
}
}
Expand Up @@ -28,6 +28,8 @@
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.util.OptionalLong;

import static io.prestosql.spi.type.BigintType.BIGINT;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
Expand All @@ -53,36 +55,36 @@ public void setUp()
public void testCreateEmptyTable()
{
createTable(0L, 0L);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0), ImmutableList.of());
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0, OptionalLong.empty()), ImmutableList.of());
}

@Test
public void testInsertPage()
{
createTable(0L, 0L);
insertToTable(0L, 0L);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE).size(), 1);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE, OptionalLong.empty()).size(), 1);
}

@Test
public void testInsertPageWithoutCreate()
{
insertToTable(0L, 0L);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE).size(), 1);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), POSITIONS_PER_PAGE, OptionalLong.empty()).size(), 1);
}

@Test(expectedExceptions = PrestoException.class)
public void testReadFromUnknownTable()
{
pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0);
pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0, OptionalLong.empty());
}

@Test(expectedExceptions = PrestoException.class)
public void testTryToReadFromEmptyTable()
{
createTable(0L, 0L);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0), ImmutableList.of());
pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 42);
assertEquals(pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 0, OptionalLong.empty()), ImmutableList.of());
pagesStore.getPages(0L, 0, 1, ImmutableList.of(0), 42, OptionalLong.empty());
}

@Test
Expand Down

0 comments on commit e63051b

Please sign in to comment.