diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 3f81f541b..5e3b9dca0 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -169,7 +169,7 @@ cc_library( "@avro", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/types:any", - "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc", + "@com_google_googleapis//google/cloud/bigquery/storage/v1:storage_cc_grpc", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", ], @@ -219,7 +219,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:any", "@com_google_absl//absl/types:variant", - "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc", + "@com_google_googleapis//google/cloud/bigquery/storage/v1:storage_cc_grpc", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", ], diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc b/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc index b724a93e4..6d2a5cce4 100644 --- a/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc +++ b/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc @@ -108,7 +108,7 @@ class BigQueryDatasetOp : public DatasetOpKernel { std::vector default_values_; std::vector typed_default_values_; int64 offset_; - apiv1beta1::DataFormat data_format_; + apiv1::DataFormat data_format_; class Dataset : public DatasetBase { public: @@ -120,7 +120,7 @@ class BigQueryDatasetOp : public DatasetOpKernel { std::vector selected_fields, std::vector output_types, std::vector typed_default_values, int64 offset_, - apiv1beta1::DataFormat data_format) + apiv1::DataFormat data_format) : DatasetBase(DatasetContext(ctx)), client_resource_(client_resource), output_types_vector_(output_types_vector), @@ -134,10 +134,10 @@ class BigQueryDatasetOp : public DatasetOpKernel { data_format_(data_format) { client_resource_->Ref(); - if (data_format == apiv1beta1::DataFormat::AVRO) { + if (data_format == apiv1::DataFormat::AVRO) { std::istringstream istream(schema); avro::compileJsonSchema(istream, *avro_schema_); - } else if (data_format == apiv1beta1::DataFormat::ARROW) { + } else if (data_format == apiv1::DataFormat::ARROW) { auto buffer_ = std::make_shared( reinterpret_cast(&schema[0]), schema.length()); @@ -158,11 +158,11 @@ class BigQueryDatasetOp : public DatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string &prefix) const override { - if (data_format_ == apiv1beta1::DataFormat::AVRO) { + if (data_format_ == apiv1::DataFormat::AVRO) { return std::unique_ptr( new BigQueryReaderAvroDatasetIterator( {this, strings::StrCat(prefix, "::BigQueryAvroDataset")})); - } else if (data_format_ == apiv1beta1::DataFormat::ARROW) { + } else if (data_format_ == apiv1::DataFormat::ARROW) { return std::unique_ptr( new BigQueryReaderArrowDatasetIterator( {this, strings::StrCat(prefix, "::BigQueryArrowDataset")})); @@ -229,7 +229,7 @@ class BigQueryDatasetOp : public DatasetOpKernel { const std::unique_ptr avro_schema_; const int64 offset_; std::shared_ptr<::arrow::Schema> arrow_schema_; - const apiv1beta1::DataFormat data_format_; + const apiv1::DataFormat data_format_; }; }; diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_kernels.cc b/tensorflow_io/core/kernels/bigquery/bigquery_kernels.cc index ea6043e6c..5e80b956f 100644 --- a/tensorflow_io/core/kernels/bigquery/bigquery_kernels.cc +++ b/tensorflow_io/core/kernels/bigquery/bigquery_kernels.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace { -namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1; +namespace apiv1 = ::google::cloud::bigquery::storage::v1; class BigQueryClientOp : public OpKernel { public: @@ -105,35 +105,30 @@ class BigQueryReadSessionOp : public OpKernel { ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); core::ScopedUnref scoped_unref(client_resource); - apiv1beta1::CreateReadSessionRequest createReadSessionRequest; - createReadSessionRequest.mutable_table_reference()->set_project_id( - project_id_); - createReadSessionRequest.mutable_table_reference()->set_dataset_id( - dataset_id_); - createReadSessionRequest.mutable_table_reference()->set_table_id(table_id_); + apiv1::CreateReadSessionRequest createReadSessionRequest; createReadSessionRequest.set_parent(parent_); - *createReadSessionRequest.mutable_read_options() - ->mutable_selected_fields() = {selected_fields_.begin(), - selected_fields_.end()}; - createReadSessionRequest.mutable_read_options()->set_row_restriction( - row_restriction_); - createReadSessionRequest.set_requested_streams(requested_streams_); - createReadSessionRequest.set_sharding_strategy( - apiv1beta1::ShardingStrategy::BALANCED); - createReadSessionRequest.set_format(data_format_); + apiv1::ReadSession* read_session = + createReadSessionRequest.mutable_read_session(); + read_session->set_table(strings::Printf( + "projects/%s/datasets/%s/tables/%s", project_id_.c_str(), + dataset_id_.c_str(), table_id_.c_str())); + read_session->set_data_format(data_format_); + *read_session->mutable_read_options()->mutable_selected_fields() = { + selected_fields_.begin(), selected_fields_.end()}; + read_session->mutable_read_options()->set_row_restriction(row_restriction_); + createReadSessionRequest.set_max_stream_count(requested_streams_); + VLOG(3) << "createReadSessionRequest: " << createReadSessionRequest.DebugString(); ::grpc::ClientContext context; - context.AddMetadata( - "x-goog-request-params", - strings::Printf("table_reference.dataset_id=%s&table_" - "reference.project_id=%s", - dataset_id_.c_str(), project_id_.c_str())); + context.AddMetadata("x-goog-request-params", + strings::Printf("read_session.table=%s", + read_session->table().c_str())); context.set_deadline(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); - std::shared_ptr readSessionResponse = - std::make_shared(); + std::shared_ptr readSessionResponse = + std::make_shared(); VLOG(3) << "calling readSession"; ::grpc::Status status = client_resource->GetStub("")->CreateReadSession( &context, createReadSessionRequest, readSessionResponse.get()); @@ -155,13 +150,13 @@ class BigQueryReadSessionOp : public OpKernel { Tensor* schema_t = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output("schema", {}, &schema_t)); - if (data_format_ == apiv1beta1::DataFormat::AVRO) { + if (data_format_ == apiv1::DataFormat::AVRO) { OP_REQUIRES(ctx, readSessionResponse->has_avro_schema(), errors::InvalidArgument("AVRO schema is missing")); VLOG(3) << "avro schema:" << readSessionResponse->avro_schema().schema(); schema_t->scalar()() = readSessionResponse->avro_schema().schema(); - } else if (data_format_ == apiv1beta1::DataFormat::ARROW) { + } else if (data_format_ == apiv1::DataFormat::ARROW) { OP_REQUIRES(ctx, readSessionResponse->has_arrow_schema(), errors::InvalidArgument("ARROW schema is missing")); VLOG(3) << "arrow schema:" @@ -183,7 +178,7 @@ class BigQueryReadSessionOp : public OpKernel { std::vector output_types_; string row_restriction_; int requested_streams_; - apiv1beta1::DataFormat data_format_; + apiv1::DataFormat data_format_; mutex mu_; ContainerInfo cinfo_ TF_GUARDED_BY(mu_); diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_lib.cc b/tensorflow_io/core/kernels/bigquery/bigquery_lib.cc index 223502cf5..276d7a1d4 100644 --- a/tensorflow_io/core/kernels/bigquery/bigquery_lib.cc +++ b/tensorflow_io/core/kernels/bigquery/bigquery_lib.cc @@ -78,11 +78,11 @@ string GrpcStatusToString(const ::grpc::Status& status) { } Status GetDataFormat(string data_format_str, - apiv1beta1::DataFormat* data_format) { + apiv1::DataFormat* data_format) { if (data_format_str == "ARROW") { - *data_format = apiv1beta1::DataFormat::ARROW; + *data_format = apiv1::DataFormat::ARROW; } else if (data_format_str == "AVRO") { - *data_format = apiv1beta1::DataFormat::AVRO; + *data_format = apiv1::DataFormat::AVRO; } else { return errors::Internal("Unsupported data format: " + data_format_str); } diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_lib.h b/tensorflow_io/core/kernels/bigquery/bigquery_lib.h index d4b9305d0..518f777b2 100644 --- a/tensorflow_io/core/kernels/bigquery/bigquery_lib.h +++ b/tensorflow_io/core/kernels/bigquery/bigquery_lib.h @@ -38,7 +38,7 @@ limitations under the License. #include "arrow/buffer.h" #include "arrow/io/memory.h" #include "arrow/ipc/api.h" -#include "google/cloud/bigquery/storage/v1beta1/storage.grpc.pb.h" +#include "google/cloud/bigquery/storage/v1/storage.grpc.pb.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -50,18 +50,18 @@ limitations under the License. namespace tensorflow { -namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1; +namespace apiv1 = ::google::cloud::bigquery::storage::v1; static constexpr int kMaxReceiveMessageSize = -1; // Disabled Status GrpcStatusToTfStatus(const ::grpc::Status &status); string GrpcStatusToString(const ::grpc::Status &status); Status GetDataFormat(string data_format_str, - apiv1beta1::DataFormat *data_format); + apiv1::DataFormat *data_format); class BigQueryClientResource : public ResourceBase { public: explicit BigQueryClientResource( - std::function( + std::function( const string &read_stream)> stub_factory) : stub_factory_(stub_factory) {} @@ -80,10 +80,10 @@ class BigQueryClientResource : public ResourceBase { args.SetString("read_stream", read_stream); auto channel = ::grpc::CreateCustomChannel(server_name, creds, args); VLOG(3) << "Creating GRPC channel"; - return absl::make_unique(channel); + return absl::make_unique(channel); }) {} - apiv1beta1::BigQueryStorage::Stub *GetStub(const string &read_stream) + apiv1::BigQueryRead::Stub *GetStub(const string &read_stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (stubs_.find(read_stream) == stubs_.end()) { auto stub = stub_factory_(read_stream); @@ -95,11 +95,11 @@ class BigQueryClientResource : public ResourceBase { string DebugString() const override { return "BigQueryClientResource"; } private: - std::function( + std::function( const string &)> stub_factory_; mutex mu_; - std::unordered_map> + std::unordered_map> stubs_ TF_GUARDED_BY(mu_); }; @@ -156,11 +156,9 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator { return OkStatus(); } - apiv1beta1::ReadRowsRequest readRowsRequest; - readRowsRequest.mutable_read_position()->mutable_stream()->set_name( - this->dataset()->stream()); - readRowsRequest.mutable_read_position()->set_offset( - this->dataset()->offset()); + apiv1::ReadRowsRequest readRowsRequest; + readRowsRequest.set_read_stream(this->dataset()->stream()); + readRowsRequest.set_offset(this->dataset()->offset()); read_rows_context_ = absl::make_unique<::grpc::ClientContext>(); // The deadline is for the entire ReadRows (not a single message receipt), @@ -169,14 +167,14 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator { std::chrono::hours(24)); read_rows_context_->AddMetadata( "x-goog-request-params", - absl::StrCat("read_position.stream.name=", - readRowsRequest.read_position().stream().name())); + absl::StrCat("read_stream=", + readRowsRequest.read_stream())); VLOG(3) << "getting reader, stream: " - << readRowsRequest.read_position().stream().DebugString(); + << readRowsRequest.read_stream(); reader_ = this->dataset() ->client_resource() - ->GetStub(readRowsRequest.read_position().stream().name()) + ->GetStub(readRowsRequest.read_stream()) ->ReadRows(read_rows_context_.get(), readRowsRequest); return OkStatus(); @@ -191,9 +189,9 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator { int current_row_index_ = 0; mutex mu_; std::unique_ptr<::grpc::ClientContext> read_rows_context_ TF_GUARDED_BY(mu_); - std::unique_ptr<::grpc::ClientReader> reader_ + std::unique_ptr<::grpc::ClientReader> reader_ TF_GUARDED_BY(mu_); - std::unique_ptr response_ TF_GUARDED_BY(mu_); + std::unique_ptr response_ TF_GUARDED_BY(mu_); }; // BigQuery reader for Arrow serialized data. @@ -213,11 +211,11 @@ class BigQueryReaderArrowDatasetIterator TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override { if (this->response_ && this->response_->has_arrow_record_batch() && this->current_row_index_ < - this->response_->arrow_record_batch().row_count()) { + this->response_->row_count()) { return OkStatus(); } - this->response_ = absl::make_unique(); + this->response_ = absl::make_unique(); if (!this->reader_->Read(this->response_.get())) { *end_of_sequence = true; return GrpcStatusToTfStatus(this->reader_->Finish()); @@ -315,11 +313,11 @@ class BigQueryReaderAvroDatasetIterator Status EnsureHasRow(bool *end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override { if (this->response_ && - this->current_row_index_ < this->response_->avro_rows().row_count()) { + this->current_row_index_ < this->response_->row_count()) { return OkStatus(); } - this->response_ = absl::make_unique(); + this->response_ = absl::make_unique(); VLOG(3) << "calling read"; if (!this->reader_->Read(this->response_.get())) { VLOG(3) << "no data"; diff --git a/tensorflow_io/core/kernels/tests/bigquery_test_client_op.cc b/tensorflow_io/core/kernels/tests/bigquery_test_client_op.cc index 51dda0399..9b68d63fb 100644 --- a/tensorflow_io/core/kernels/tests/bigquery_test_client_op.cc +++ b/tensorflow_io/core/kernels/tests/bigquery_test_client_op.cc @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { namespace { -namespace apiv1beta1 = ::google::cloud::bigquery::storage::v1beta1; +namespace apiv1 = ::google::cloud::bigquery::storage::v1; class BigQueryTestClientOp : public OpKernel { public: @@ -58,13 +58,13 @@ class BigQueryTestClientOp : public OpKernel { std::shared_ptr channel = ::grpc::CreateChannel(this->fake_server_address_, grpc::InsecureChannelCredentials()); - auto stub = apiv1beta1::BigQueryStorage::NewStub(channel); + auto stub = apiv1::BigQueryRead::NewStub(channel); LOG(INFO) << "BigQueryTestClientOp waiting for connections"; channel->WaitForConnected( gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(15, GPR_TIMESPAN))); LOG(INFO) << "Done creating BigQueryTestClientOp Fake client"; - return absl::make_unique( + return absl::make_unique( channel); }); return OkStatus(); diff --git a/tests/test_bigquery.py b/tests/test_bigquery.py index 4eef2ecdd..c06c7fef1 100644 --- a/tests/test_bigquery.py +++ b/tests/test_bigquery.py @@ -34,8 +34,15 @@ BigQueryClient, ) # pylint: disable=wrong-import-order -import google.cloud.bigquery_storage_v1beta1.proto.storage_pb2_grpc as storage_pb2_grpc # pylint: disable=wrong-import-order -import google.cloud.bigquery_storage_v1beta1.proto.storage_pb2 as storage_pb2 # pylint: disable=wrong-import-order +import google.cloud.bigquery_storage_v1 as bigquery_storage # pylint: disable=wrong-import-order +from google.cloud.bigquery_storage_v1 import types as storage_pb2 # pylint: disable=wrong-import-order +from google.cloud.bigquery_storage_v1.services.big_query_read import ( + BigQueryReadServicer, + BigQueryReadStub, +) # pylint: disable=wrong-import-order +import google.cloud.bigquery_storage_v1.services.big_query_read.transports.grpc as bigquery_storage_grpc # pylint: disable=wrong-import-order + +storage_pb2_grpc = bigquery_storage_grpc.BigQueryReadGrpcTransport if sys.platform == "darwin": pytest.skip("TODO: macOS is failing", allow_module_level=True) @@ -44,7 +51,7 @@ tf.compat.v1.enable_eager_execution() -class FakeBigQueryServer(storage_pb2_grpc.BigQueryStorageServicer): +class FakeBigQueryServer(BigQueryReadServicer): """Fake server for Cloud BigQuery Storage API.""" def __init__( @@ -62,7 +69,9 @@ def __init__( self._grpc_server = grpc.server( concurrent.futures.ThreadPoolExecutor(max_workers=4) ) - storage_pb2_grpc.add_BigQueryStorageServicer_to_server(self, self._grpc_server) + bigquery_storage_grpc.BigQueryReadGrpcTransport.add_server( + self._grpc_server, self + ) port = self._grpc_server.add_insecure_port("localhost:0") self._endpoint = "localhost:" + str(port) print("started a fake server on :" + self._endpoint) @@ -101,31 +110,33 @@ def serialize_to_avro(rows, schema): def CreateReadSession(self, request, context): # pylint: disable=unused-argument """CreateReadSession""" print("called CreateReadSession on a fake server") - self._project_id = request.table_reference.project_id - self._table_id = request.table_reference.table_id - self._dataset_id = request.table_reference.dataset_id + table_path = request.read_session.table + parts = table_path.split("/") + self._project_id = parts[1] + self._dataset_id = parts[3] + self._table_id = parts[5] self._streams = [] response = storage_pb2.ReadSession() response.avro_schema.schema = self._avro_schema - for i in range(request.requested_streams): + for i in range(request.max_stream_count): stream_name = self._build_stream_name(i) self._streams.append(stream_name) - stream = response.streams.add() - stream.name = stream_name + stream = storage_pb2.ReadStream(name=stream_name) + response.streams.append(stream) return response def ReadRows(self, request, context): # pylint: disable=unused-argument """ReadRows""" print("called ReadRows on a fake server: %s" % str(request)) response = storage_pb2.ReadRowsResponse() - stream_index = self._streams.index(request.read_position.stream.name) + stream_index = self._streams.index(request.read_stream) if 0 <= stream_index < len(self._rows_per_stream): - rows = self._rows_per_stream[stream_index][request.read_position.offset :] + rows = self._rows_per_stream[stream_index][request.offset :] serialized_rows = FakeBigQueryServer.serialize_to_avro( rows, self._avro_schema ) response.avro_rows.serialized_binary_rows = serialized_rows - response.avro_rows.row_count = len(rows) + response.row_count = len(rows) yield response @@ -459,24 +470,23 @@ def _get_read_session( def test_fake_server(self): """Fake server test.""" channel = grpc.insecure_channel(BigqueryOpsTest.server.endpoint()) - stub = storage_pb2_grpc.BigQueryStorageStub(channel) + stub = BigQueryReadStub(channel) create_read_session_request = storage_pb2.CreateReadSessionRequest() - create_read_session_request.table_reference.project_id = self.GCP_PROJECT_ID - create_read_session_request.table_reference.dataset_id = self.DATASET_ID - create_read_session_request.table_reference.table_id = self.TABLE_ID - create_read_session_request.requested_streams = 2 + create_read_session_request.read_session.table = ( + "projects/%s/datasets/%s/tables/%s" + % (self.GCP_PROJECT_ID, self.DATASET_ID, self.TABLE_ID) + ) + create_read_session_request.max_stream_count = 2 read_session_response = stub.CreateReadSession(create_read_session_request) self.assertEqual(2, len(read_session_response.streams)) read_rows_request = storage_pb2.ReadRowsRequest() - read_rows_request.read_position.stream.name = read_session_response.streams[ - 0 - ].name + read_rows_request.read_stream = read_session_response.streams[0].name read_rows_response = stub.ReadRows(read_rows_request) - row = read_rows_response.next() + row = next(read_rows_response) self.assertEqual( FakeBigQueryServer.serialize_to_avro(self.STREAM_1_ROWS, self.AVRO_SCHEMA), row.avro_rows.serialized_binary_rows, @@ -484,11 +494,9 @@ def test_fake_server(self): self.assertEqual(len(self.STREAM_1_ROWS), row.avro_rows.row_count) read_rows_request = storage_pb2.ReadRowsRequest() - read_rows_request.read_position.stream.name = read_session_response.streams[ - 1 - ].name + read_rows_request.read_stream = read_session_response.streams[1].name read_rows_response = stub.ReadRows(read_rows_request) - row = read_rows_response.next() + row = next(read_rows_response) self.assertEqual( FakeBigQueryServer.serialize_to_avro(self.STREAM_2_ROWS, self.AVRO_SCHEMA), row.avro_rows.serialized_binary_rows,