Use Storage Write API in BigQuery connector
ebyhr committed Nov 12, 2023
1 parent 6d1ed93 commit 6d1e190
Showing 15 changed files with 189 additions and 44 deletions.
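For orientation, the snippet below is a minimal, self-contained sketch (not part of the commit) of the Storage Write API pattern that BigQueryPageSink adopts in the diff: create a COMMITTED write stream on the target table, then append JSON rows through a JsonStreamWriter. The project, dataset, table, and column names are placeholders, not values taken from the commit.

import com.google.api.core.ApiFuture;
import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
import com.google.cloud.bigquery.storage.v1.CreateWriteStreamRequest;
import com.google.cloud.bigquery.storage.v1.JsonStreamWriter;
import com.google.cloud.bigquery.storage.v1.TableName;
import com.google.cloud.bigquery.storage.v1.WriteStream;
import org.json.JSONArray;
import org.json.JSONObject;

public class StorageWriteApiSketch
{
    public static void main(String[] args)
            throws Exception
    {
        // Placeholder identifiers; substitute a real project, dataset, and table.
        TableName table = TableName.of("example-project", "example_dataset", "example_table");
        try (BigQueryWriteClient client = BigQueryWriteClient.create()) {
            // A COMMITTED stream makes rows visible as soon as each append is acknowledged.
            WriteStream stream = client.createWriteStream(CreateWriteStreamRequest.newBuilder()
                    .setParent(table.toString())
                    .setWriteStream(WriteStream.newBuilder().setType(WriteStream.Type.COMMITTED).build())
                    .build());

            try (JsonStreamWriter writer = JsonStreamWriter.newBuilder(stream.getName(), stream.getTableSchema(), client).build()) {
                JSONArray rows = new JSONArray();
                rows.put(new JSONObject().put("id", 1).put("name", "alice"));

                ApiFuture<AppendRowsResponse> future = writer.append(rows);
                AppendRowsResponse response = future.get(); // throws if the RPC failed
                if (response.hasError()) {
                    throw new IllegalStateException(response.getError().getMessage());
                }
            }
        }
    }
}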
11 changes: 11 additions & 0 deletions plugin/trino-bigquery/pom.xml
@@ -103,6 +103,11 @@
</exclusions>
</dependency>

<dependency>
<groupId>com.google.api.grpc</groupId>
<artifactId>proto-google-common-protos</artifactId>
</dependency>

<dependency>
<groupId>com.google.auth</groupId>
<artifactId>google-auth-library-credentials</artifactId>
@@ -304,6 +309,12 @@
<artifactId>httpcore</artifactId>
</dependency>

<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20231013</version>
</dependency>

<dependency>
<groupId>org.threeten</groupId>
<artifactId>threetenbp</artifactId>
plugin/trino-bigquery/.../BigQueryConnectorModule.java
@@ -68,6 +68,7 @@ protected void setup(Binder binder)
{
// BigQuery related
binder.bind(BigQueryReadClientFactory.class).in(Scopes.SINGLETON);
binder.bind(BigQueryWriteClientFactory.class).in(Scopes.SINGLETON);
binder.bind(BigQueryClientFactory.class).in(Scopes.SINGLETON);

// Connector implementation
plugin/trino-bigquery/.../BigQueryErrorCode.java
@@ -31,6 +31,7 @@ public enum BigQueryErrorCode
BIGQUERY_UNSUPPORTED_OPERATION(5, USER_ERROR),
BIGQUERY_INVALID_STATEMENT(6, USER_ERROR),
BIGQUERY_PROXY_SSL_INITIALIZATION_FAILED(7, EXTERNAL),
BIGQUERY_BAD_WRITE(8, EXTERNAL),
/**/;

private final ErrorCode errorCode;
plugin/trino-bigquery/.../BigQueryGrpcOptionsConfigurer.java
@@ -15,6 +15,7 @@

import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.cloud.bigquery.storage.v1.BigQueryReadSettings;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
import io.trino.spi.connector.ConnectorSession;

interface BigQueryGrpcOptionsConfigurer
@@ -27,5 +28,12 @@ default BigQueryReadSettings.Builder configure(BigQueryReadSettings.Builder buil
return builder.setTransportChannelProvider(configure(channelBuilder, session).build());
}

@Override
default BigQueryWriteSettings.Builder configure(BigQueryWriteSettings.Builder builder, ConnectorSession session)
{
InstantiatingGrpcChannelProvider.Builder channelBuilder = ((InstantiatingGrpcChannelProvider) builder.getTransportChannelProvider()).toBuilder();
return builder.setTransportChannelProvider(configure(channelBuilder, session).build());
}

InstantiatingGrpcChannelProvider.Builder configure(InstantiatingGrpcChannelProvider.Builder channelBuilder, ConnectorSession session);
}
plugin/trino-bigquery/.../BigQueryOptionsConfigurer.java
@@ -15,11 +15,14 @@

import com.google.cloud.bigquery.BigQueryOptions;
import com.google.cloud.bigquery.storage.v1.BigQueryReadSettings;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
import io.trino.spi.connector.ConnectorSession;

interface BigQueryOptionsConfigurer
{
BigQueryOptions.Builder configure(BigQueryOptions.Builder builder, ConnectorSession session);

BigQueryReadSettings.Builder configure(BigQueryReadSettings.Builder builder, ConnectorSession session);

BigQueryWriteSettings.Builder configure(BigQueryWriteSettings.Builder builder, ConnectorSession session);
}
plugin/trino-bigquery/.../BigQueryPageSink.java
@@ -13,40 +13,49 @@
*/
package io.trino.plugin.bigquery;

import com.google.cloud.bigquery.InsertAllRequest;
import com.google.cloud.bigquery.TableId;
import com.google.api.core.ApiFuture;
import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
import com.google.cloud.bigquery.storage.v1.CreateWriteStreamRequest;
import com.google.cloud.bigquery.storage.v1.JsonStreamWriter;
import com.google.cloud.bigquery.storage.v1.TableName;
import com.google.cloud.bigquery.storage.v1.WriteStream;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorPageSinkId;
import io.trino.spi.type.Type;
import org.json.JSONArray;
import org.json.JSONObject;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static com.google.cloud.bigquery.storage.v1.WriteStream.Type.COMMITTED;
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_BAD_WRITE;
import static io.trino.plugin.bigquery.BigQueryTypeUtils.readNativeValue;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;

public class BigQueryPageSink
implements ConnectorPageSink
{
private final BigQueryClient client;
private final TableId tableId;
private final BigQueryWriteClient client;
private final TableName tableName;
private final List<String> columnNames;
private final List<Type> columnTypes;
private final ConnectorPageSinkId pageSinkId;
private final Optional<String> pageSinkIdColumnName;

public BigQueryPageSink(
BigQueryClient client,
BigQueryWriteClient client,
RemoteTableName remoteTableName,
List<String> columnNames,
List<Type> columnTypes,
@@ -64,28 +73,46 @@ public BigQueryPageSink(
this.pageSinkIdColumnName = requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null");
checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(),
"temporaryTableName.isPresent is not equal to pageSinkIdColumn.isPresent");
this.tableId = temporaryTableName
.map(tableName -> TableId.of(remoteTableName.getProjectId(), remoteTableName.getDatasetName(), tableName))
.orElseGet(remoteTableName::toTableId);
this.tableName = temporaryTableName
.map(tableName -> TableName.of(remoteTableName.getProjectId(), remoteTableName.getDatasetName(), tableName))
.orElseGet(remoteTableName::toTableName);
}

@Override
public CompletableFuture<?> appendPage(Page page)
{
InsertAllRequest.Builder batch = InsertAllRequest.newBuilder(tableId);
JSONArray batch = new JSONArray();
for (int position = 0; position < page.getPositionCount(); position++) {
Map<String, Object> row = new HashMap<>();
JSONObject row = new JSONObject();
pageSinkIdColumnName.ifPresent(column -> row.put(column, pageSinkId.getId()));
for (int channel = 0; channel < page.getChannelCount(); channel++) {
row.put(columnNames.get(channel), readNativeValue(columnTypes.get(channel), page.getBlock(channel), position));
}
batch.addRow(row);
batch.put(row);
}

client.insert(batch.build());
insertWithCommitted(batch);
return NOT_BLOCKED;
}

private void insertWithCommitted(JSONArray batch)
{
WriteStream stream = WriteStream.newBuilder().setType(COMMITTED).build();
CreateWriteStreamRequest createWriteStreamRequest = CreateWriteStreamRequest.newBuilder().setParent(tableName.toString()).setWriteStream(stream).build();
WriteStream writeStream = client.createWriteStream(createWriteStreamRequest);

try (JsonStreamWriter writer = JsonStreamWriter.newBuilder(writeStream.getName(), writeStream.getTableSchema(), client).build()) {
ApiFuture<AppendRowsResponse> future = writer.append(batch);
AppendRowsResponse response = future.get(); // wait for the append to complete; throws if the RPC failed
if (response.hasError()) {
throw new TrinoException(BIGQUERY_BAD_WRITE, format("Response has error: %s", response.getError().getMessage()));
}
}
catch (Exception e) {
throw new TrinoException(BIGQUERY_BAD_WRITE, "Failed to insert rows", e);
}
}

@Override
public CompletableFuture<Collection<Slice>> finish()
{
plugin/trino-bigquery/.../BigQueryPageSinkProvider.java
@@ -29,10 +29,10 @@
public class BigQueryPageSinkProvider
implements ConnectorPageSinkProvider
{
private final BigQueryClientFactory clientFactory;
private final BigQueryWriteClientFactory clientFactory;

@Inject
public BigQueryPageSinkProvider(BigQueryClientFactory clientFactory)
public BigQueryPageSinkProvider(BigQueryWriteClientFactory clientFactory)
{
this.clientFactory = requireNonNull(clientFactory, "clientFactory is null");
}
@@ -42,7 +42,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa
{
BigQueryOutputTableHandle handle = (BigQueryOutputTableHandle) outputTableHandle;
return new BigQueryPageSink(
clientFactory.createBigQueryClient(session),
clientFactory.create(session),
handle.getRemoteTableName(),
handle.getColumnNames(),
handle.getColumnTypes(),
@@ -56,7 +56,7 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa
{
BigQueryInsertTableHandle handle = (BigQueryInsertTableHandle) insertTableHandle;
return new BigQueryPageSink(
clientFactory.createBigQueryClient(session),
clientFactory.create(session),
handle.getRemoteTableName(),
handle.getColumnNames(),
handle.getColumnTypes(),
plugin/trino-bigquery/.../BigQueryType.java
@@ -139,7 +139,7 @@ public static String dateToStringConverter(Object value)
return "'" + date + "'";
}

private static String datetimeToStringConverter(Object value)
public static String datetimeToStringConverter(Object value)
{
long epochMicros = (long) value;
long epochSeconds = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
plugin/trino-bigquery/.../BigQueryTypeUtils.java
@@ -13,7 +13,7 @@
*/
package io.trino.plugin.bigquery;

import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlRow;
@@ -24,16 +24,15 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import jakarta.annotation.Nullable;
import org.json.JSONArray;
import org.json.JSONObject;

import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static io.trino.plugin.bigquery.BigQueryType.datetimeToStringConverter;
import static io.trino.plugin.bigquery.BigQueryType.timestampToStringConverter;
import static io.trino.plugin.bigquery.BigQueryType.toZonedDateTime;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
@@ -45,19 +44,15 @@
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static java.lang.Math.floorDiv;
import static java.lang.Math.floorMod;
import static java.time.ZoneOffset.UTC;
import static java.util.Collections.unmodifiableMap;

public final class BigQueryTypeUtils
{
private static final long MIN_SUPPORTED_DATE = LocalDate.parse("0001-01-01").toEpochDay();
private static final long MAX_SUPPORTED_DATE = LocalDate.parse("9999-12-31").toEpochDay();

private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("uuuu-MM-dd");
private static final DateTimeFormatter DATETIME_FORMATTER = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS");

private BigQueryTypeUtils() {}

@@ -73,10 +68,10 @@ public static Object readNativeValue(Type type, Block block, int position)
return BOOLEAN.getBoolean(block, position);
}
if (type.equals(TINYINT)) {
return TINYINT.getByte(block, position);
return type.getLong(block, position);
}
if (type.equals(SMALLINT)) {
return SMALLINT.getShort(block, position);
return SMALLINT.getLong(block, position);
}
if (type.equals(INTEGER)) {
return INTEGER.getInt(block, position);
@@ -94,33 +89,34 @@ public static Object readNativeValue(Type type, Block block, int position)
return varcharType.getSlice(block, position).toStringUtf8();
}
if (type.equals(VARBINARY)) {
return Base64.getEncoder().encodeToString(VARBINARY.getSlice(block, position).getBytes());
return ByteString.copyFrom(VARBINARY.getSlice(block, position).getBytes());
}
if (type.equals(DATE)) {
int days = DATE.getInt(block, position);
if (days < MIN_SUPPORTED_DATE || days > MAX_SUPPORTED_DATE) {
throw new TrinoException(NOT_SUPPORTED, "BigQuery supports dates between 0001-01-01 and 9999-12-31 but got " + LocalDate.ofEpochDay(days));
}
return DATE_FORMATTER.format(LocalDate.ofEpochDay(days));
}
if (type.equals(TIMESTAMP_MICROS)) {
long epochMicros = TIMESTAMP_MICROS.getLong(block, position);
long epochSeconds = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
int nanoAdjustment = floorMod(epochMicros, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND;
return DATETIME_FORMATTER.format(toZonedDateTime(epochSeconds, nanoAdjustment, UTC));
return datetimeToStringConverter(epochMicros);
}
if (type.equals(TIMESTAMP_TZ_MICROS)) {
LongTimestampWithTimeZone timestamp = (LongTimestampWithTimeZone) TIMESTAMP_TZ_MICROS.getObject(block, position);
return timestampToStringConverter(timestamp);
}
if (type instanceof ArrayType arrayType) {
Block arrayBlock = arrayType.getObject(block, position);
ImmutableList.Builder<Object> list = ImmutableList.builderWithExpectedSize(arrayBlock.getPositionCount());
JSONArray list = new JSONArray();
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
Object element = readNativeValue(arrayType.getElementType(), arrayBlock, i);
if (element == null) {
throw new TrinoException(NOT_SUPPORTED, "BigQuery does not support null elements in arrays");
}
list.add(element);
list.put(element);
}
return list.build();
return list;
}
if (type instanceof RowType rowType) {
SqlRow sqlRow = rowType.getObject(block, position);
@@ -131,13 +127,13 @@ public static Object readNativeValue(Type type, Block block, int position)
}

int rawIndex = sqlRow.getRawIndex();
Map<String, Object> rowValue = new HashMap<>();
JSONObject rowValue = new JSONObject();
for (int fieldIndex = 0; fieldIndex < sqlRow.getFieldCount(); fieldIndex++) {
String fieldName = rowType.getFields().get(fieldIndex).getName().orElseThrow(() -> new IllegalArgumentException("Field name must exist in BigQuery"));
Object fieldValue = readNativeValue(fieldTypes.get(fieldIndex), sqlRow.getRawFieldBlock(fieldIndex), rawIndex);
rowValue.put(fieldName, fieldValue);
}
return unmodifiableMap(rowValue);
return rowValue;
}

throw new TrinoException(NOT_SUPPORTED, "Unsupported type: " + type);