Skip to content

Commit

Permalink
Implement CREATE OR REPLACE TABLE for delta lake connector
Browse files Browse the repository at this point in the history
  • Loading branch information
Praveen2112 committed Apr 2, 2024
1 parent a502b7b commit 65a9146
Show file tree
Hide file tree
Showing 12 changed files with 1,000 additions and 45 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY;
Expand All @@ -43,6 +44,8 @@ public class DeltaLakeOutputTableHandle
private final ColumnMappingMode columnMappingMode;
private final OptionalInt maxColumnId;
private final String schemaString;
private final boolean replace;
private final OptionalLong readVersion;
private final ProtocolEntry protocolEntry;

@JsonCreator
Expand All @@ -58,6 +61,8 @@ public DeltaLakeOutputTableHandle(
@JsonProperty("schemaString") String schemaString,
@JsonProperty("columnMappingMode") ColumnMappingMode columnMappingMode,
@JsonProperty("maxColumnId") OptionalInt maxColumnId,
@JsonProperty("replace") boolean replace,
@JsonProperty("readVersion") OptionalLong readVersion,
@JsonProperty("protocolEntry") ProtocolEntry protocolEntry)
{
this.schemaName = requireNonNull(schemaName, "schemaName is null");
Expand All @@ -71,6 +76,8 @@ public DeltaLakeOutputTableHandle(
this.schemaString = requireNonNull(schemaString, "schemaString is null");
this.columnMappingMode = requireNonNull(columnMappingMode, "columnMappingMode is null");
this.maxColumnId = requireNonNull(maxColumnId, "maxColumnId is null");
this.replace = replace;
this.readVersion = requireNonNull(readVersion, "readVersion is null");
this.protocolEntry = requireNonNull(protocolEntry, "protocolEntry is null");
}

Expand Down Expand Up @@ -149,6 +156,18 @@ public OptionalInt getMaxColumnId()
return maxColumnId;
}

@JsonProperty
public boolean isReplace()
{
return replace;
}

@JsonProperty
public OptionalLong getReadVersion()
{
return readVersion;
}

@JsonProperty
public ProtocolEntry getProtocolEntry()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public interface DeltaLakeMetastore

void createTable(Table table, PrincipalPrivileges principalPrivileges);

void replaceTable(Table table, PrincipalPrivileges principalPrivileges);

void dropTable(SchemaTableName schemaTableName, String tableLocation, boolean deleteData);

void renameTable(SchemaTableName from, SchemaTableName to);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ public void createTable(Table table, PrincipalPrivileges principalPrivileges)
delegate.createTable(table, principalPrivileges);
}

@Override
public void replaceTable(Table table, PrincipalPrivileges principalPrivileges)
{
delegate.replaceTable(table.getDatabaseName(), table.getTableName(), table, principalPrivileges);
}

@Override
public void dropTable(SchemaTableName schemaTableName, String tableLocation, boolean deleteData)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.Futures;
import io.airlift.concurrent.MoreFutures;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
Expand Down Expand Up @@ -45,6 +46,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -67,13 +69,16 @@
import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING;
import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static io.trino.plugin.base.util.Closables.closeAllSuppress;
import static io.trino.plugin.deltalake.DeltaLakeMetadata.CREATE_OR_REPLACE_TABLE_AS_OPERATION;
import static io.trino.plugin.deltalake.DeltaLakeMetadata.CREATE_OR_REPLACE_TABLE_OPERATION;
import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.DELTA_CATALOG;
import static io.trino.plugin.deltalake.DeltaLakeQueryRunner.createDockerizedDeltaLakeQueryRunner;
import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.EXTENDED_STATISTICS_COLLECT_ON_WRITE;
import static io.trino.plugin.deltalake.TestingDeltaLakeUtils.getConnectorService;
import static io.trino.plugin.deltalake.TestingDeltaLakeUtils.getTableActiveFiles;
import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.TRANSACTION_LOG_DIRECTORY;
import static io.trino.plugin.hive.TestingThriftHiveMetastoreBuilder.testingThriftHiveMetastoreBuilder;
import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder;
import static io.trino.testing.QueryAssertions.getTrinoExceptionCause;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.DELETE_TABLE;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE;
Expand Down Expand Up @@ -2294,6 +2299,140 @@ public void testConcurrentInsertsReconciliationForBlindInserts()
testConcurrentInsertsReconciliationForBlindInserts(true);
}

@Test
public void testCreateOrReplaceTable()
{
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName()))
.matches("VALUES (BIGINT '42', -385e-1)");

assertUpdate("CREATE OR REPLACE TABLE %s (a bigint, b double)".formatted(table.getName()));
assertQueryReturnsEmptyResult("SELECT * FROM " + table.getName());

assertThat(getTableVersion(table.getName())).isEqualTo(1);
assertTableOperation(table.getName(), 1, CREATE_OR_REPLACE_TABLE_OPERATION);
}
}

@Test
public void testCreateOrReplaceTableAs()
{
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName()))
.matches("VALUES (BIGINT '42', -385e-1)");

assertUpdate("CREATE OR REPLACE TABLE %s AS SELECT BIGINT '-53' a, DOUBLE '49.6' b".formatted(table.getName()), 1);
assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName()))
.matches("VALUES (BIGINT '-53', 496e-1)");

assertThat(1L).isEqualTo(getTableVersion(table.getName()));
assertTableOperation(table.getName(), 1, CREATE_OR_REPLACE_TABLE_AS_OPERATION);
}
}

@Test
public void testCreateOrReplaceTableChangeColumnNamesAndTypes()
{
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_table", " AS SELECT BIGINT '42' a, DOUBLE '-38.5' b")) {
assertThat(query("SELECT CAST(a AS bigint), b FROM " + table.getName()))
.matches("VALUES (BIGINT '42', -385e-1)");

assertUpdate("CREATE OR REPLACE TABLE " + table.getName() + " AS SELECT VARCHAR 'test' c, VARCHAR 'test2' d", 1);
assertThat(query("SELECT c, d FROM " + table.getName()))
.matches("VALUES (VARCHAR 'test', VARCHAR 'test2')");

assertThat(1L).isEqualTo(getTableVersion(table.getName()));
assertTableOperation(table.getName(), 1, CREATE_OR_REPLACE_TABLE_AS_OPERATION);
}
}

@RepeatedTest(3)
// Test from BaseConnectorTest
public void testCreateOrReplaceTableConcurrently()
throws Exception
{
int threads = 4;
int numOfCreateOrReplaceStatements = 4;
int numOfReads = 16;
CyclicBarrier barrier = new CyclicBarrier(threads + 1);
ExecutorService executor = newFixedThreadPool(threads + 1);
List<Future<?>> futures = new ArrayList<>();
try (TestTable table = new TestTable(getQueryRunner()::execute, "test_create_or_replace", "(col integer)")) {
String tableName = table.getName();

getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT 1 a");
assertThat(query("SELECT * FROM " + tableName)).matches("VALUES 1");

// One thread submits some CREATE OR REPLACE statements
futures.add(executor.submit(() -> {
barrier.await(30, SECONDS);
IntStream.range(0, numOfCreateOrReplaceStatements).forEach(index -> {
try {
getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT * FROM (VALUES (1), (2)) AS t(a) ");
} catch (Exception e) {
RuntimeException trinoException = getTrinoExceptionCause(e);
try {
throw new AssertionError("Unexpected concurrent CREATE OR REPLACE failure", trinoException);
} catch (Throwable verifyFailure) {
if (verifyFailure != e) {
verifyFailure.addSuppressed(e);
}
throw verifyFailure;
}
}
});
return null;
}));
// Other 4 threads continue try to read the same table, none of the reads should fail.
IntStream.range(0, threads)
.forEach(threadNumber -> futures.add(executor.submit(() -> {
barrier.await(30, SECONDS);
IntStream.range(0, numOfReads).forEach(readIndex -> {
try {
MaterializedResult result = computeActual("SELECT * FROM " + tableName);
if (result.getRowCount() == 1) {
assertEqualsIgnoreOrder(result.getMaterializedRows(), List.of(new MaterializedRow(List.of(1))));
}
else {
assertEqualsIgnoreOrder(result.getMaterializedRows(), List.of(new MaterializedRow(List.of(1)), new MaterializedRow(List.of(2))));
}
}
catch (Exception e) {
RuntimeException trinoException = getTrinoExceptionCause(e);
try {
throw new AssertionError("Unexpected concurrent CREATE OR REPLACE failure", trinoException);
}
catch (Throwable verifyFailure) {
if (verifyFailure != e) {
verifyFailure.addSuppressed(e);
}
throw verifyFailure;
}
}
});
return null;
})));
futures.forEach(Futures::getUnchecked);
getQueryRunner().execute("CREATE OR REPLACE TABLE " + tableName + " AS SELECT * FROM (VALUES (1), (2), (3)) AS t(a)");
assertThat(query("SELECT * FROM " + tableName)).matches("VALUES 1, 2, 3");
}
finally {
executor.shutdownNow();
executor.awaitTermination(30, SECONDS);
}
}

private long getTableVersion(String tableName)
{
return (Long) computeActual(format("SELECT max(version) FROM \"%s$history\"", tableName)).getOnlyValue();
}

private void assertTableOperation(String tableName, long version, String operation)
{
assertQuery("SELECT operation FROM \"%s$history\" WHERE version = %s".formatted(tableName, version),
"VALUES '%s'".formatted(operation));
}

private void testConcurrentInsertsReconciliationForBlindInserts(boolean partitioned)
throws Exception
{
Expand Down

0 comments on commit 65a9146

Please sign in to comment.