Skip to content

Commit

Permalink
Implement CREATE OR REPLACE TABLE for delta lake connector
Browse files Browse the repository at this point in the history
  • Loading branch information
Praveen2112 committed Dec 13, 2023
1 parent e729e6c commit b2a6b28
Show file tree
Hide file tree
Showing 12 changed files with 904 additions and 42 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY;
Expand All @@ -43,6 +44,8 @@ public class DeltaLakeOutputTableHandle
private final ColumnMappingMode columnMappingMode;
private final OptionalInt maxColumnId;
private final String schemaString;
private final boolean replace;
private final OptionalLong readVersion;
private final ProtocolEntry protocolEntry;

@JsonCreator
Expand All @@ -58,6 +61,8 @@ public DeltaLakeOutputTableHandle(
@JsonProperty("schemaString") String schemaString,
@JsonProperty("columnMappingMode") ColumnMappingMode columnMappingMode,
@JsonProperty("maxColumnId") OptionalInt maxColumnId,
@JsonProperty("replace") boolean replace,
@JsonProperty("readVersion") OptionalLong readVersion,
@JsonProperty("protocolEntry") ProtocolEntry protocolEntry)
{
this.schemaName = requireNonNull(schemaName, "schemaName is null");
Expand All @@ -71,6 +76,8 @@ public DeltaLakeOutputTableHandle(
this.schemaString = requireNonNull(schemaString, "schemaString is null");
this.columnMappingMode = requireNonNull(columnMappingMode, "columnMappingMode is null");
this.maxColumnId = requireNonNull(maxColumnId, "maxColumnId is null");
this.replace = replace;
this.readVersion = requireNonNull(readVersion, "readVersion is null");
this.protocolEntry = requireNonNull(protocolEntry, "protocolEntry is null");
}

Expand Down Expand Up @@ -149,6 +156,18 @@ public OptionalInt getMaxColumnId()
return maxColumnId;
}

/**
 * Returns the {@code replace} flag this handle was constructed with.
 * NOTE(review): presumably {@code true} when the handle belongs to a
 * CREATE OR REPLACE TABLE operation; confirm against the code that builds the handle.
 */
@JsonProperty
public boolean isReplace()
{
    return this.replace;
}

/**
 * Returns the {@code readVersion} this handle was constructed with; never {@code null}
 * (the constructor rejects a null value), but possibly empty.
 * NOTE(review): presumably the table version read before the write started; confirm with the caller.
 */
@JsonProperty
public OptionalLong getReadVersion()
{
    return this.readVersion;
}

@JsonProperty
public ProtocolEntry getProtocolEntry()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public interface DeltaLakeMetastore

/**
 * Registers {@code table} in the metastore with the given privileges.
 */
void createTable(ConnectorSession session, Table table, PrincipalPrivileges principalPrivileges);

/**
 * Replaces an existing metastore table with {@code table}.
 * NOTE(review): the visible implementation identifies the table to replace by
 * {@code table.getDatabaseName()} and {@code table.getTableName()}; other implementations should be confirmed.
 */
void replaceTable(ConnectorSession session, Table table, PrincipalPrivileges principalPrivileges);

/**
 * Drops the table named by {@code schemaTableName}.
 * NOTE(review): {@code deleteData} presumably controls whether the files under
 * {@code tableLocation} are removed as well; confirm against the implementations.
 */
void dropTable(ConnectorSession session, SchemaTableName schemaTableName, String tableLocation, boolean deleteData);

/**
 * Renames the table {@code from} to {@code to}.
 */
void renameTable(ConnectorSession session, SchemaTableName from, SchemaTableName to);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ public void createTable(ConnectorSession session, Table table, PrincipalPrivileg
delegate.createTable(table, principalPrivileges);
}

/**
 * Replaces the metastore entry that has the same database and table name as {@code table}
 * by forwarding to the underlying delegate. The {@code session} argument is not used here.
 */
@Override
public void replaceTable(ConnectorSession session, Table table, PrincipalPrivileges principalPrivileges)
{
    String databaseName = table.getDatabaseName();
    String tableName = table.getTableName();
    delegate.replaceTable(databaseName, tableName, table, principalPrivileges);
}

@Override
public void dropTable(ConnectorSession session, SchemaTableName schemaTableName, String tableLocation, boolean deleteData)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.Futures;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
Expand All @@ -41,15 +42,22 @@
import io.trino.tpch.TpchTable;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.Timeout;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.BiConsumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -59,13 +67,17 @@
import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING;
import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static io.trino.plugin.base.util.Closables.closeAllSuppress;
import static io.trino.plugin.deltalake.DeltaLakeMetadata.CREATE_OR_REPLACE_TABLE_AS_OPERATION;
import static io.trino.plugin.deltalake.DeltaLakeMetadata.CREATE_OR_REPLACE_TABLE_OPERATION;
import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG;
import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.createDockerizedDeltaLakeQueryRunner;
import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.EXTENDED_STATISTICS_COLLECT_ON_WRITE;
import static io.trino.plugin.deltalake.TestingDeltaLakeUtils.getConnectorService;
import static io.trino.plugin.deltalake.TestingDeltaLakeUtils.getTableActiveFiles;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.TRANSACTION_LOG_DIRECTORY;
import static io.trino.plugin.hive.TestingThriftHiveMetastoreBuilder.testingThriftHiveMetastoreBuilder;
import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder;
import static io.trino.testing.QueryAssertions.getTrinoExceptionCause;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.DELETE_TABLE;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE;
Expand All @@ -79,6 +91,7 @@
import static io.trino.tpch.TpchTable.ORDERS;
import static java.lang.String.format;
import static java.util.Comparator.comparing;
import static java.util.concurrent.Executors.newFixedThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -2231,6 +2244,141 @@ public void testPartitionFilterIncluded()
}
}

/**
 * Replaces a populated table through {@code CREATE OR REPLACE TABLE} with an explicit column
 * list and checks that the table becomes empty and that version 1 of its history records
 * the create-or-replace operation.
 */
@Test
public void testCreateOrReplaceTable()
{
    try (TestTable testTable = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
        String tableName = testTable.getName();

        // The table starts out with the single row written by the initial CTAS
        assertThat(query("SELECT CAST(a AS bigint), b FROM " + tableName))
                .matches("VALUES (BIGINT '42', -385e-1)");

        // Replacing without a query must leave no rows behind
        assertUpdate("CREATE OR REPLACE TABLE " + tableName + " (a bigint, b double)");
        assertQueryReturnsEmptyResult("SELECT * FROM " + tableName);

        // The replace is the second commit, i.e. version 1
        long expectedVersion = 1;
        assertThat(getTableVersion(tableName)).isEqualTo(expectedVersion);
        assertTableOperation(tableName, expectedVersion, CREATE_OR_REPLACE_TABLE_OPERATION);
    }
}

/**
 * Replaces a populated table through {@code CREATE OR REPLACE TABLE ... AS SELECT} and checks
 * that only the new row is visible and that version 1 of the history records the
 * create-or-replace-as operation.
 */
@Test
public void testCreateOrReplaceTableAs()
{
    try (TestTable testTable = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
        String tableName = testTable.getName();
        String selectAll = "SELECT CAST(a AS bigint), b FROM " + tableName;

        // Row written by the initial CTAS
        assertThat(query(selectAll))
                .matches("VALUES (BIGINT '42', -385e-1)");

        // The replacement writes exactly one new row, which must be the only one visible afterwards
        assertUpdate("CREATE OR REPLACE TABLE " + tableName + " AS SELECT BIGINT '-53' a, DOUBLE '49.6' b", 1);
        assertThat(query(selectAll))
                .matches("VALUES (BIGINT '-53', 496e-1)");

        long expectedVersion = 1;
        assertThat(getTableVersion(tableName)).isEqualTo(expectedVersion);
        assertTableOperation(tableName, expectedVersion, CREATE_OR_REPLACE_TABLE_AS_OPERATION);
    }
}

/**
 * Checks that {@code CREATE OR REPLACE TABLE ... AS SELECT} may change both the column names
 * and the column types of an existing table, and that the replace is recorded as version 1.
 */
@Test
public void testCreateOrReplaceTableChangeColumnNamesAndTypes()
{
    try (TestTable testTable = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
        String tableName = testTable.getName();

        // Original schema: (a bigint, b double)
        assertThat(query("SELECT CAST(a AS bigint), b FROM " + tableName))
                .matches("VALUES (BIGINT '42', -385e-1)");

        // Replacement schema: (c varchar, d varchar)
        assertUpdate("CREATE OR REPLACE TABLE %s AS SELECT VARCHAR 'test' c, VARCHAR 'test2' d".formatted(tableName), 1);
        assertThat(query("SELECT c, d FROM " + tableName))
                .matches("VALUES (VARCHAR 'test', VARCHAR 'test2')");

        long expectedVersion = 1;
        assertThat(getTableVersion(tableName)).isEqualTo(expectedVersion);
        assertTableOperation(tableName, expectedVersion, CREATE_OR_REPLACE_TABLE_AS_OPERATION);
    }
}

/**
 * Runs one writer thread issuing {@code CREATE OR REPLACE TABLE ... AS SELECT} statements while
 * several reader threads query the same table, and requires that no statement fails and that
 * every read observes either the one-row or the two-row table contents.
 *
 * <p>Only {@code @RepeatedTest} is declared: combining it with {@code @Test} makes JUnit Jupiter
 * register the method twice (one plain execution plus the repetitions) and report a
 * configuration warning.
 */
@RepeatedTest(4)
@Timeout(60)
// Test from BaseConnectorTest
public void testCreateOrReplaceTableConcurrently()
        throws Exception
{
    int threads = 4;
    int numOfCreateOrReplaceStatements = 4;
    int numOfReads = 16;
    // One party per reader plus one for the writer, so all workers start at the same time
    CyclicBarrier barrier = new CyclicBarrier(threads + 1);
    ExecutorService executor = newFixedThreadPool(threads + 1);
    List<Future<?>> futures = new ArrayList<>();
    try (TestTable table = new TestTable(getQueryRunner()::execute, "test_create_or_replace", "(col integer)")) {
        String tableName = table.getName();

        getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT 1 a");
        assertThat(query("SELECT * FROM " + tableName)).matches("VALUES 1");

        // One thread submits some CREATE OR REPLACE statements
        futures.add(executor.submit(() -> {
            barrier.await(30, SECONDS);
            IntStream.range(0, numOfCreateOrReplaceStatements).forEach(index -> {
                try {
                    getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT * FROM (VALUES (1), (2)) AS t(a) ");
                }
                catch (Exception e) {
                    throw concurrentCreateOrReplaceFailure(e);
                }
            });
            return null;
        }));
        // The other threads keep reading the same table; none of the reads should fail.
        IntStream.range(0, threads)
                .forEach(threadNumber -> futures.add(executor.submit(() -> {
                    barrier.await(30, SECONDS);
                    IntStream.range(0, numOfReads).forEach(readIndex -> {
                        try {
                            MaterializedResult result = computeActual("SELECT * FROM " + tableName);
                            // A read sees either the initial single row or the two rows written by the writer
                            if (result.getRowCount() == 1) {
                                assertEqualsIgnoreOrder(result.getMaterializedRows(), List.of(new MaterializedRow(List.of(1))));
                            }
                            else {
                                assertEqualsIgnoreOrder(result.getMaterializedRows(), List.of(new MaterializedRow(List.of(1)), new MaterializedRow(List.of(2))));
                            }
                        }
                        catch (Exception e) {
                            throw concurrentCreateOrReplaceFailure(e);
                        }
                    });
                    return null;
                })));
        // Propagates the first worker failure, if any
        futures.forEach(Futures::getUnchecked);
        getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT * FROM (VALUES (1), (2), (3)) AS t(a)");
        assertThat(query("SELECT * FROM " + tableName)).matches("VALUES 1, 2, 3");
    }
    finally {
        executor.shutdownNow();
        executor.awaitTermination(30, SECONDS);
    }
}

// Wraps a failure raised by a concurrent statement: the Trino exception found in its causal chain
// becomes the cause and the original exception is kept as suppressed, so neither stack trace is lost.
private static AssertionError concurrentCreateOrReplaceFailure(Exception e)
{
    AssertionError failure = new AssertionError("Unexpected concurrent CREATE OR REPLACE failure", getTrinoExceptionCause(e));
    failure.addSuppressed(e);
    return failure;
}

/**
 * Returns the highest version listed in the {@code $history} table of {@code tableName}.
 */
private long getTableVersion(String tableName)
{
    String historyQuery = format("SELECT max(version) FROM \"%s$history\"", tableName);
    Object maxVersion = computeActual(historyQuery).getOnlyValue();
    return (Long) maxVersion;
}

/**
 * Asserts that the {@code $history} table of {@code tableName} records {@code operation}
 * for the given {@code version}.
 */
private void assertTableOperation(String tableName, long version, String operation)
{
    String actualSql = "SELECT operation FROM \"%s$history\" WHERE version = %s".formatted(tableName, version);
    String expectedSql = "VALUES '%s'".formatted(operation);
    assertQuery(actualSql, expectedSql);
}
protected List<String> listCheckpointFiles(String transactionLogDirectory)
{
return listFiles(transactionLogDirectory).stream()
Expand Down

0 comments on commit b2a6b28

Please sign in to comment.