diff --git a/tensorflow_io/bigquery.md b/tensorflow_io/bigquery.md
index 986ba0e67..0eff4b66e 100644
--- a/tensorflow_io/bigquery.md
+++ b/tensorflow_io/bigquery.md
@@ -57,7 +57,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow_io.bigquery import BigQueryClient
 from tensorflow_io.bigquery import BigQueryReadSession
 
-GCP_PROJECT_ID = ''
+GCP_PROJECT_ID = ""
 DATASET_GCP_PROJECT_ID = "bigquery-public-data"
 DATASET_ID = "samples"
 TABLE_ID = "wikipedia"
@@ -68,20 +68,29 @@ def main():
   read_session = client.read_session(
       "projects/" + GCP_PROJECT_ID,
       DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID,
-      ["title",
+      selected_fields=["title",
        "id",
        "num_characters",
        "language",
        "timestamp",
        "wp_namespace",
        "contributor_username"],
-      [dtypes.string,
+      output_types=[dtypes.string,
        dtypes.int64,
        dtypes.int64,
        dtypes.string,
        dtypes.int64,
        dtypes.int64,
        dtypes.string],
+      default_values=[
+          "",
+          0,
+          0,
+          "",
+          0,
+          0,
+          ""
+      ],
       requested_streams=2,
       row_restriction="num_characters > 1000",
       data_format=BigQueryClient.DataFormat.AVRO)
@@ -98,8 +107,8 @@ def main():
     print("row %d: %s" % (row_index, row))
     row_index += 1
 
-if __name__ == '__main__':
-  app.run(main)
+if __name__ == "__main__":
+  main()
 ```
@@ -127,10 +136,10 @@ dataset = streams_ds.interleave(
 The connector also supports reading BigQuery columns with repeated mode (each field
 contains an array of values of a primitive type: Integer, Float, Boolean, or String;
 RECORD is not supported). In this case, selected_fields needs to be a dictionary of the form:
 ```python
-  { "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.int64},
-    "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, output_type: dtypes.string},
+  { "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.int64},
+    "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, "output_type": dtypes.string, "default_value": ""},
     ...
- "field_x_name": {"mode": BigQueryClient.FieldMode.REQUIRED, output_type: dtypes.string} + "field_x_name": {"mode": BigQueryClient.FieldMode.REQUIRED, "output_type": dtypes.string} } ``` "mode" is BigQuery column attribute concept, it can be 'repeated', 'nullable' or 'required' (enum BigQueryClient.FieldMode.REPEATED, NULLABLE, REQUIRED).The output field order is unrelated to the order of fields in @@ -144,7 +153,7 @@ from tensorflow.python.framework import dtypes from tensorflow_io.bigquery import BigQueryClient from tensorflow_io.bigquery import BigQueryReadSession -GCP_PROJECT_ID = '' +GCP_PROJECT_ID = "" DATASET_GCP_PROJECT_ID = "bigquery-public-data" DATASET_ID = "certain_dataset" TABLE_ID = "certain_table_with_repeated_field" @@ -156,10 +165,10 @@ def main(): "projects/" + GCP_PROJECT_ID, DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID, selected_fiels={ - "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.int64}, - "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, output_type: dtypes.string}, - "field_c_name": {"mode": BigQueryClient.FieldMode.REQUIRED, output_type: dtypes.string} - "field_d_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.string} + "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.int64}, + "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, "output_type": dtypes.string, "default_value": ""}, + "field_c_name": {"mode": BigQueryClient.FieldMode.REQUIRED, "output_type": dtypes.string} + "field_d_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.string} } requested_streams=2, row_restriction="num_characters > 1000", @@ -171,8 +180,8 @@ def main(): print("row %d: %s" % (row_index, row)) row_index += 1 -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + main() ``` Then each field of a repeated column becomes a rank-1 variable length Tensor. If you want to diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 87c516938..04de70c81 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -168,6 +168,7 @@ cc_library( "@arrow", "@avro", "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/types:any", "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", @@ -190,6 +191,7 @@ cc_library( "@com_google_absl//absl/algorithm", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:any", "@com_google_absl//absl/types:variant", "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc", "@local_config_tf//:libtensorflow_framework", diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc b/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc index d6b47b31b..185da4e3f 100644 --- a/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc +++ b/tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
 #include
 #include
+#include "absl/types/any.h"
 #include "arrow/buffer.h"
 #include "arrow/ipc/api.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -30,6 +31,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
   explicit BigQueryDatasetOp(OpKernelConstruction *ctx) : DatasetOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("selected_fields", &selected_fields_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("default_values", &default_values_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
     string data_format_str;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
@@ -54,20 +56,53 @@ class BigQueryDatasetOp : public DatasetOpKernel {
     output_shapes.reserve(num_outputs);
     DataTypeVector output_types_vector;
     output_types_vector.reserve(num_outputs);
+    typed_default_values_.reserve(num_outputs);
     for (uint64 i = 0; i < num_outputs; ++i) {
       output_shapes.push_back({});
       output_types_vector.push_back(output_types_[i]);
+      const DataType &output_type = output_types_[i];
+      const string &default_value = default_values_[i];
+      switch (output_type) {
+        case DT_FLOAT:
+          typed_default_values_.push_back(absl::any(std::stof(default_value)));
+          break;
+        case DT_DOUBLE:
+          typed_default_values_.push_back(absl::any(std::stod(default_value)));
+          break;
+        case DT_INT32:
+          int32_t value_int32_t;
+          strings::safe_strto32(default_value, &value_int32_t);
+          typed_default_values_.push_back(absl::any(value_int32_t));
+          break;
+        case DT_INT64:
+          int64_t value_int64_t;
+          strings::safe_strto64(default_value, &value_int64_t);
+          typed_default_values_.push_back(absl::any(value_int64_t));
+          break;
+        case DT_BOOL:
+          typed_default_values_.push_back(absl::any(default_value == "True"));
+          break;
+        case DT_STRING:
+          typed_default_values_.push_back(absl::any(default_value));
+          break;
+        default:
+          ctx->CtxFailure(
+              errors::InvalidArgument("Unsupported output_type:", output_type));
+          break;
+      }
     }
 
     *output = new Dataset(ctx, client_resource, output_types_vector,
                           std::move(output_shapes), std::move(stream),
                           std::move(schema), selected_fields_, output_types_,
-                          offset_, data_format_);
+                          typed_default_values_, offset_, data_format_);
   }
 
  private:
  std::vector<string> selected_fields_;
  std::vector<DataType> output_types_;
+  std::vector<string> default_values_;
+  std::vector<absl::any> typed_default_values_;
  int64 offset_;
  apiv1beta1::DataFormat data_format_;
@@ -79,7 +114,8 @@ class BigQueryDatasetOp : public DatasetOpKernel {
             std::vector<PartialTensorShape> output_shapes, string stream,
             string schema, std::vector<string> selected_fields,
-            std::vector<DataType> output_types, int64 offset_,
+            std::vector<DataType> output_types,
+            std::vector<absl::any> typed_default_values, int64 offset_,
             apiv1beta1::DataFormat data_format)
         : DatasetBase(DatasetContext(ctx)),
           client_resource_(client_resource),
@@ -88,6 +124,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
           stream_(stream),
           selected_fields_(selected_fields),
           output_types_(output_types),
+          typed_default_values_(typed_default_values),
          offset_(offset_),
          avro_schema_(absl::make_unique<avro::ValidSchema>()),
          data_format_(data_format) {
@@ -147,6 +184,10 @@ class BigQueryDatasetOp : public DatasetOpKernel {
    const std::vector<DataType> &output_types() const { return output_types_; }
 
+    const std::vector<absl::any> &typed_default_values() const {
+      return typed_default_values_;
+    }
+
    const std::unique_ptr<avro::ValidSchema> &avro_schema() const {
      return avro_schema_;
    }
@@ -180,6 +221,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
    const string stream_;
    const std::vector<string> selected_fields_;
    const std::vector<DataType> output_types_;
+    const std::vector<absl::any> typed_default_values_;
    const std::unique_ptr<avro::ValidSchema> avro_schema_;
    const int64 offset_;
    std::shared_ptr<::arrow::Schema> arrow_schema_;
diff --git a/tensorflow_io/core/kernels/bigquery/bigquery_lib.h b/tensorflow_io/core/kernels/bigquery/bigquery_lib.h
index d0aab7f95..68a80a3fd 100644
--- a/tensorflow_io/core/kernels/bigquery/bigquery_lib.h
+++ b/tensorflow_io/core/kernels/bigquery/bigquery_lib.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include
 #undef OPTIONAL
 #endif
+#include "absl/types/any.h"
 #include "api/Compiler.hh"
 #include "api/DataFile.hh"
 #include "api/Decoder.hh"
@@ -127,7 +128,8 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
     auto status =
         ReadRecord(ctx, out_tensors, this->dataset()->selected_fields(),
-                   this->dataset()->output_types());
+                   this->dataset()->output_types(),
+                   this->dataset()->typed_default_values());
     current_row_index_++;
     return status;
   }
@@ -181,10 +183,11 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
   }
 
   virtual Status EnsureHasRow(bool *end_of_sequence) = 0;
-  virtual Status ReadRecord(IteratorContext *ctx,
-                            std::vector<Tensor> *out_tensors,
-                            const std::vector<string> &columns,
-                            const std::vector<DataType> &output_types) = 0;
+  virtual Status ReadRecord(
+      IteratorContext *ctx, std::vector<Tensor> *out_tensors,
+      const std::vector<string> &columns,
+      const std::vector<DataType> &output_types,
+      const std::vector<absl::any> &typed_default_values) = 0;
   int current_row_index_ = 0;
   mutex mu_;
   std::unique_ptr<::grpc::ClientContext> read_rows_context_ TF_GUARDED_BY(mu_);
@@ -245,7 +248,8 @@ class BigQueryReaderArrowDatasetIterator
   Status ReadRecord(IteratorContext *ctx, std::vector<Tensor> *out_tensors,
                     const std::vector<string> &columns,
-                    const std::vector<DataType> &output_types)
+                    const std::vector<DataType> &output_types,
+                    const std::vector<absl::any> &typed_default_values)
       TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
     out_tensors->clear();
     out_tensors->reserve(columns.size());
@@ -253,7 +257,6 @@ class BigQueryReaderArrowDatasetIterator
     if (this->current_row_index_ == 0 && this->column_indices_.empty()) {
       this->column_indices_.resize(columns.size());
       for (size_t i = 0; i < columns.size(); ++i) {
-        DataType output_type = output_types[i];
         auto column_name = this->record_batch_->column_name(i);
         auto it = std::find(columns.begin(), columns.end(), column_name);
         if (it == columns.end()) {
@@ -337,7 +340,8 @@ class BigQueryReaderAvroDatasetIterator
   Status ReadRecord(IteratorContext *ctx, std::vector<Tensor> *out_tensors,
                     const std::vector<string> &columns,
-                    const std::vector<DataType> &output_types)
+                    const std::vector<DataType> &output_types,
+                    const std::vector<absl::any> &typed_default_values)
       TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
     avro::decode(*this->decoder_, *this->datum_);
     if (this->datum_->type() != avro::AVRO_RECORD) {
@@ -521,22 +525,28 @@ class BigQueryReaderAvroDatasetIterator
       case avro::AVRO_NULL:
         switch (output_types[i]) {
          case DT_BOOL:
-            ((*out_tensors)[i]).scalar<bool>()() = false;
+            ((*out_tensors)[i]).scalar<bool>()() =
+                absl::any_cast<bool>(typed_default_values[i]);
            break;
          case DT_INT32:
-            ((*out_tensors)[i]).scalar<int32>()() = 0;
+            ((*out_tensors)[i]).scalar<int32>()() =
+                absl::any_cast<int32_t>(typed_default_values[i]);
            break;
          case DT_INT64:
-            ((*out_tensors)[i]).scalar<int64>()() = 0l;
+            ((*out_tensors)[i]).scalar<int64>()() =
+                absl::any_cast<int64_t>(typed_default_values[i]);
            break;
          case DT_FLOAT:
-            ((*out_tensors)[i]).scalar<float>()() = 0.0f;
+            ((*out_tensors)[i]).scalar<float>()() =
+                absl::any_cast<float>(typed_default_values[i]);
            break;
          case DT_DOUBLE:
-            ((*out_tensors)[i]).scalar<double>()() = 0.0;
+            ((*out_tensors)[i]).scalar<double>()() =
+                absl::any_cast<double>(typed_default_values[i]);
            break;
          case DT_STRING:
-            ((*out_tensors)[i]).scalar<tstring>()() = "";
+            ((*out_tensors)[i]).scalar<tstring>()() =
+                absl::any_cast<string>(typed_default_values[i]);
            break;
          default:
            return errors::InvalidArgument(
diff --git a/tensorflow_io/core/ops/bigquery_ops.cc b/tensorflow_io/core/ops/bigquery_ops.cc
index 0c239699d..9d2c1d20d 100644
--- a/tensorflow_io/core/ops/bigquery_ops.cc
+++ b/tensorflow_io/core/ops/bigquery_ops.cc
@@ -32,6 +32,7 @@ REGISTER_OP("IO>BigQueryReadSession")
     .Attr("dataset_id: string")
     .Attr("selected_fields: list(string) >= 1")
     .Attr("output_types: list(type) >= 1")
+    .Attr("default_values: list(string) >= 1")
     .Attr("requested_streams: int")
     .Attr("data_format: string")
     .Attr("row_restriction: string = ''")
@@ -53,6 +54,7 @@ REGISTER_OP("IO>BigQueryDataset")
     .Attr("data_format: string")
     .Attr("selected_fields: list(string) >= 1")
     .Attr("output_types: list(type) >= 1")
+    .Attr("default_values: list(string) >= 1")
     .Output("handle: variant")
     .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
diff --git a/tensorflow_io/python/ops/bigquery_dataset_ops.py b/tensorflow_io/python/ops/bigquery_dataset_ops.py
index bec2e3ad3..fd8ce0401 100644
--- a/tensorflow_io/python/ops/bigquery_dataset_ops.py
+++ b/tensorflow_io/python/ops/bigquery_dataset_ops.py
@@ -68,6 +68,7 @@ def read_session(
         dataset_id,
         selected_fields,
         output_types=None,
+        default_values=None,
         row_restriction="",
         requested_streams=1,
         data_format: DataFormat = DataFormat.AVRO,
@@ -84,10 +85,10 @@ def read_session(
            selected_fields: This can be a list or a dict. If a list, it has names of
                the fields in the table that should be read. If a dict, it should be
                in a form like, i.e:
-                { "field_a_name": {"mode": "repeated", output_type: dtypes.int64},
-                  "field_b_name": {"mode": "nullable", output_type: dtypes.string},
+                { "field_a_name": {"mode": "repeated", "output_type": dtypes.int64},
+                  "field_b_name": {"mode": "nullable", "output_type": dtypes.int32, "default_value": 0},
                  ...
-                  "field_x_name": {"mode": "repeated", output_type: dtypes.string}
+                  "field_x_name": {"mode": "repeated", "output_type": dtypes.string, "default_value": ""}
                }
                "mode" is BigQuery column attribute, it can be 'repeated', 'nullable'
                or 'required'. The output field order is unrelated to the order of fields in
@@ -98,6 +99,10 @@ def read_session(
                if selected_fields is a dictionary, this output_types information is
                included in selected_fields as described above. If not specified,
                DT_STRING is implied for all Tensors.
+            default_values: Default values to use when the underlying tensor is
+                "null", in the same sequence as selected_fields. If not specified,
+                meaningful defaults are used (0 for numeric types, an empty string
+                for strings, and False for booleans).
            row_restriction: Optional. SQL text filtering statement, similar to a
                WHERE clause in a query.
            requested_streams: Desirable number of streams that can be read in parallel.
@@ -135,27 +140,45 @@ def read_session(
            raise ValueError("`requested_streams` must be a positive number")
 
        if isinstance(selected_fields, list):
-            if not isinstance(output_types, list):
-                raise ValueError(
-                    "`output_types` must be a list if selected_fields is list"
-                )
-            if output_types and len(output_types) != len(selected_fields):
-                raise ValueError(
-                    "lengths of `output_types` must be a same as the "
-                    "length of `selected_fields`"
-                )
-            if not output_types:
+            if output_types is not None:
+                if not isinstance(output_types, list):
+                    raise ValueError(
+                        "`output_types` must be a list if selected_fields is list"
+                    )
+                if len(output_types) != len(selected_fields):
+                    raise ValueError(
+                        "length of `output_types` must be the same as the "
+                        "length of `selected_fields`"
+                    )
+            else:
                output_types = [dtypes.string] * len(selected_fields)
-            # Repeated field is not supported if selected_fields is list
+            # Repeated fields are not supported if selected_fields is list
            selected_fields_repeated = [False] * len(selected_fields)
+            if default_values is None:
+                default_values = []
+                for output_type in output_types:
+                    default_values.append(self._get_default_value_for_type(output_type))
+            else:
+                if not isinstance(default_values, list):
+                    raise ValueError(
+                        "`default_values` must be a list if selected_fields is list"
+                    )
+                if len(default_values) != len(selected_fields):
+                    raise ValueError(
+                        "length of `default_values` must be the same as the "
+                        "length of `selected_fields`"
+                    )
+                default_values = [
+                    str(default_value) for default_value in default_values
+                ]
        elif isinstance(selected_fields, dict):
            _selected_fields = []
            selected_fields_repeated = []
            output_types = []
-            for field in selected_fields:
+            default_values = []
+            for field, field_attr_dict in selected_fields.items():
                _selected_fields.append(field)
-                mode = selected_fields[field].get("mode", self.FieldMode.NULLABLE)
+                mode = field_attr_dict.get("mode", self.FieldMode.NULLABLE)
                if mode == self.FieldMode.REPEATED:
                    selected_fields_repeated.append(True)
                elif mode == self.FieldMode.NULLABLE or mode == self.FieldMode.REQUIRED:
@@ -164,9 +187,13 @@ def read_session(
                    raise ValueError(
                        "mode needs be BigQueryClient.FieldMode.NULLABLE, FieldMode.REQUIRED or FieldMode.REPEATED"
                    )
-                output_types.append(
-                    selected_fields[field].get("output_type", dtypes.string)
-                )
+                output_type = field_attr_dict.get("output_type", dtypes.string)
+                output_types.append(output_type)
+                if "default_value" in field_attr_dict:
+                    default_value = str(field_attr_dict["default_value"])
+                else:
+                    default_value = self._get_default_value_for_type(output_type)
+                default_values.append(default_value)
            selected_fields = _selected_fields
        else:
            raise ValueError("`selected_fields` must be a list or dict.")
@@ -181,6 +208,7 @@ def read_session(
            data_format=data_format.value,
            selected_fields=selected_fields,
            output_types=output_types,
+            default_values=default_values,
            row_restriction=row_restriction,
        )
        return BigQueryReadSession(
@@ -191,6 +219,7 @@ def read_session(
            selected_fields,
            selected_fields_repeated,
            output_types,
+            default_values,
            row_restriction,
            requested_streams,
            data_format,
@@ -199,6 +228,12 @@ def read_session(
            self._client_resource,
        )
 
+    def _get_default_value_for_type(self, output_type):
+        if output_type == tf.string:
+            return ""
+        else:
+            return str(output_type.as_numpy_dtype())
+
 
 class BigQueryReadSession:
    """Entry point for reading data from Cloud BigQuery."""
@@ -212,6 +247,7 @@ def __init__(
        selected_fields,
        selected_fields_repeated,
        output_types,
+        default_values,
        row_restriction,
        requested_streams,
        data_format,
@@ -226,6 +262,7 @@ def __init__(
        self._selected_fields = selected_fields
        self._selected_fields_repeated = selected_fields_repeated
        self._output_types = output_types
+        self._default_values = default_values
        self._row_restriction = row_restriction
        self._requested_streams = requested_streams
        self._data_format = data_format
@@ -259,6 +296,7 @@ def read_rows(self, stream, offset=0):
            self._selected_fields,
            self._selected_fields_repeated,
            self._output_types,
+            self._default_values,
            self._schema,
            self._data_format,
            stream,
@@ -325,6 +363,7 @@ def __init__(
        selected_fields,
        selected_fields_repeated,
        output_types,
+        default_values,
        schema,
        data_format,
        stream,
@@ -333,15 +372,18 @@ def __init__(
        # selected_fields and corresponding output_types have to be sorted because
        # of b/141251314
        sorted_fields_with_types = sorted(
-            zip(selected_fields, selected_fields_repeated, output_types),
+            zip(
+                selected_fields, selected_fields_repeated, output_types, default_values
+            ),
            key=itemgetter(0),
        )
-        selected_fields, selected_fields_repeated, output_types = list(
+        selected_fields, selected_fields_repeated, output_types, default_values = list(
            zip(*sorted_fields_with_types)
        )
        selected_fields = list(selected_fields)
        selected_fields_repeated = list(selected_fields_repeated)
        output_types = list(output_types)
+        default_values = list(default_values)
        tensor_shapes = list(
            [
@@ -366,6 +408,7 @@ def __init__(
            client=client_resource,
            selected_fields=selected_fields,
            output_types=output_types,
+            default_values=default_values,
            schema=schema,
            data_format=data_format.value,
            stream=stream,
diff --git a/tests/test_bigquery.py b/tests/test_bigquery.py
index a89c8b1d2..5765e70de 100644
--- a/tests/test_bigquery.py
+++ b/tests/test_bigquery.py
@@ -288,12 +288,114 @@ class BigqueryOpsTest(test.TestCase):
        "repeated_float": [1000.0, 700.0, 1200.0],
        "repeated_double": [101.0, 10.1, 0.3, 1.4],
        "repeated_string": ["string1", "string2", "string3", "string4"],
-        "repeated_string": [
-            "string1",
-            "string2",
-            "string3",
-            "string4",
-        ],
+        "repeated_string": ["string1", "string2", "string3", "string4"],
+        "repeated_double": [101.0, 10.1, 0.3, 1.4],
+    }
+
+    SELECTED_FIELDS_LIST = [
+        "string",
+        "boolean",
+        "int",
+        "long",
+        "float",
+        "double",
+    ]
+
+    OUTPUT_TYPES_LIST = [
+        dtypes.string,
+        dtypes.bool,
+        dtypes.int32,
+        dtypes.int64,
+        dtypes.float32,
+        dtypes.float64,
+    ]
+
+    SELECTED_FIELDS_DICT = {
+        "string": {"output_type": dtypes.string},
+        "boolean": {"output_type": dtypes.bool},
+        "int": {"output_type": dtypes.int32},
+        "long": {"output_type": dtypes.int64},
+        "float": {"output_type": dtypes.float32},
+        "double": {"output_type": dtypes.float64},
+        "repeated_bool": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.bool,
+        },
+        "repeated_int": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.int32,
+        },
+        "repeated_long": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.int64,
+        },
+        "repeated_float": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.float32,
+        },
+        "repeated_double": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.float64,
+        },
+        "repeated_string": {
+            "mode": BigQueryClient.FieldMode.REPEATED,
+            "output_type": dtypes.string,
+        },
+    }
+
+    SELECTED_FIELDS_DICT_WITH_DEFAULTS = {
+        "string": {"output_type": dtypes.string, "default_value": "abc"},
+        "boolean": {"output_type": dtypes.bool, "default_value": True},
+        "int": {"output_type": dtypes.int32, "default_value": 10},
+ "long": {"output_type": dtypes.int64, "default_value": 100}, + "float": {"output_type": dtypes.float32, "default_value": 100.0}, + "double": {"output_type": dtypes.float64, "default_value": 1000.0}, + "repeated_bool": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.bool, + "default_value": True, + }, + "repeated_int": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.int32, + "default_value": -10, + }, + "repeated_long": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.int64, + "default_value": -100, + }, + "repeated_float": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.float32, + "default_value": -1000.01, + }, + "repeated_double": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.float64, + "default_value": -1000.001, + }, + "repeated_string": { + "mode": BigQueryClient.FieldMode.REPEATED, + "output_type": dtypes.string, + "default_value": "def", + }, + } + + CUSTOM_DEFAULT_VALUES = { + "boolean": True, + "double": 1000.0, + "float": 100.0, + "int": 10, + "long": 100, + "string": "abc", + "repeated_bool": [False, True, True], + "repeated_int": [30, 40, 20], + "repeated_long": [200, 300, 900], + "repeated_float": [1000.0, 700.0, 1200.0], + "repeated_double": [101.0, 10.1, 0.3, 1.4], + "repeated_string": ["string1", "string2", "string3", "string4"], + "repeated_string": ["string1", "string2", "string3", "string4"], "repeated_double": [101.0, 10.1, 0.3, 1.4], } @@ -336,71 +438,18 @@ def tearDownClass(cls): # pylint: disable=invalid-name """setUpClass""" cls.server.stop() - def _get_read_session(self, client, nonrepeated_only=False): - if nonrepeated_only: - return client.read_session( - self.PARENT, - self.GCP_PROJECT_ID, - self.TABLE_ID, - self.DATASET_ID, - selected_fields=[ - "string", - "boolean", - "int", - "long", - "float", - "double", - ], - output_types=[ - dtypes.string, - dtypes.bool, - dtypes.int32, - dtypes.int64, - dtypes.float32, - dtypes.float64, - ], - requested_streams=2, - ) - else: - return client.read_session( - self.PARENT, - self.GCP_PROJECT_ID, - self.TABLE_ID, - self.DATASET_ID, - selected_fields={ - "string": {"output_type": dtypes.string}, - "boolean": {"output_type": dtypes.bool}, - "int": {"output_type": dtypes.int32}, - "long": {"output_type": dtypes.int64}, - "float": {"output_type": dtypes.float32}, - "double": {"output_type": dtypes.float64}, - "repeated_bool": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.bool, - }, - "repeated_int": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.int32, - }, - "repeated_long": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.int64, - }, - "repeated_float": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.float32, - }, - "repeated_double": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.float64, - }, - "repeated_string": { - "mode": BigQueryClient.FieldMode.REPEATED, - "output_type": dtypes.string, - }, - }, - requested_streams=2, - ) + def _get_read_session( + self, client, selected_fields, output_types=None, requested_streams=2 + ): + return client.read_session( + self.PARENT, + self.GCP_PROJECT_ID, + self.TABLE_ID, + self.DATASET_ID, + selected_fields=selected_fields, + output_types=output_types, + requested_streams=2, + ) def test_fake_server(self): """Fake server test.""" @@ -444,7 +493,9 @@ def test_fake_server(self): def test_read_rows(self): """Test for reading rows.""" client = 
-        read_session = self._get_read_session(client)
+        read_session = self._get_read_session(
+            client, selected_fields=self.SELECTED_FIELDS_DICT
+        )
 
        streams_list = read_session.get_streams()
        self.assertEqual(len(streams_list), 2)
@@ -470,10 +521,35 @@ def test_read_rows(self):
        with self.assertRaises(errors.OutOfRangeError):
            itr2.get_next()
 
+    def test_read_rows_default_values(self):
+        """Test for reading rows when default values are specified."""
+        client = BigQueryTestClient(BigqueryOpsTest.server.endpoint())
+
+        read_session = self._get_read_session(
+            client, selected_fields=self.SELECTED_FIELDS_DICT_WITH_DEFAULTS
+        )
+
+        streams_list = read_session.get_streams()
+        self.assertEqual(len(streams_list), 2)
+        dataset2 = read_session.read_rows(streams_list[1])
+        itr2 = iter(dataset2)
+        self.assertEqual(
+            self.STREAM_2_ROWS[0], self._normalize_dictionary(itr2.get_next())
+        )
+        self.assertEqual(
+            self.CUSTOM_DEFAULT_VALUES, self._normalize_dictionary(itr2.get_next())
+        )
+        with self.assertRaises(errors.OutOfRangeError):
+            itr2.get_next()
+
    def test_read_rows_nonrepeated_only(self):
        """Test for reading rows with non-repeated fields only, then selected_fields
        and output_types are list (backward compatible)."""
        client = BigQueryTestClient(BigqueryOpsTest.server.endpoint())
-        read_session = self._get_read_session(client, nonrepeated_only=True)
+        read_session = self._get_read_session(
+            client,
+            selected_fields=self.SELECTED_FIELDS_LIST,
+            output_types=self.OUTPUT_TYPES_LIST,
+        )
 
        streams_list = read_session.get_streams()
        self.assertEqual(len(streams_list), 2)
@@ -506,7 +582,9 @@ def test_read_rows_nonrepeated_only(self):
    def test_read_rows_with_offset(self):
        """Test for reading rows with offset."""
        client = BigQueryTestClient(BigqueryOpsTest.server.endpoint())
-        read_session = self._get_read_session(client)
+        read_session = self._get_read_session(
+            client, selected_fields=self.SELECTED_FIELDS_DICT
+        )
 
        streams_list = read_session.get_streams()
        self.assertEqual(len(streams_list), 2)
@@ -521,7 +599,9 @@ def test_read_rows_with_offset(self):
    def test_parallel_read_rows(self):
        """Test for reading rows in parallel."""
        client = BigQueryTestClient(BigqueryOpsTest.server.endpoint())
-        read_session = self._get_read_session(client)
+        read_session = self._get_read_session(
+            client, selected_fields=self.SELECTED_FIELDS_DICT
+        )
 
        dataset = read_session.parallel_read_rows()
        itr = iter(dataset)
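
Reviewer note: a minimal usage sketch of the `default_values` behavior this PR introduces, following the documented API above. The project, dataset, table, and column names (`name`, `score`) are placeholders, not part of the change; this is untested illustration, not the connector's documentation.

```python
from tensorflow.python.framework import dtypes
from tensorflow_io.bigquery import BigQueryClient

# Hypothetical identifiers; substitute real project/dataset/table names.
GCP_PROJECT_ID = "my-project"
DATASET_GCP_PROJECT_ID = "my-data-project"
DATASET_ID = "my_dataset"
TABLE_ID = "my_table"

client = BigQueryClient()

# List form: default_values is parallel to selected_fields/output_types.
# NULL cells in the table are filled with the corresponding default.
read_session = client.read_session(
    "projects/" + GCP_PROJECT_ID,
    DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID,
    selected_fields=["name", "score"],
    output_types=[dtypes.string, dtypes.float32],
    default_values=["<missing>", -1.0],
    requested_streams=2,
)

# Dict form: a per-field "default_value" entry; fields without one fall back
# to the built-in defaults (0, empty string, False) based on output_type.
read_session = client.read_session(
    "projects/" + GCP_PROJECT_ID,
    DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID,
    selected_fields={
        "name": {"mode": BigQueryClient.FieldMode.NULLABLE,
                 "output_type": dtypes.string,
                 "default_value": "<missing>"},
        "score": {"mode": BigQueryClient.FieldMode.NULLABLE,
                  "output_type": dtypes.float32,
                  "default_value": -1.0},
    },
    requested_streams=2,
)

for row in read_session.parallel_read_rows():
    print(row)
```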
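The op attribute `default_values` is a list of strings, so the kernel re-parses each default according to its `output_type` (the switch added in bigquery_dataset_op.cc). A rough Python rendering of those conversion rules, assuming the kernel code above is the source of truth; note in particular that a boolean default is true only when the string is exactly "True":

```python
# Sketch of the string-to-typed-default mapping performed in the C++ kernel;
# this helper is illustrative only and is not part of the Python API.
def typed_default(output_type_name: str, default_value: str):
    if output_type_name in ("float32", "float64"):
        return float(default_value)      # std::stof / std::stod
    if output_type_name in ("int32", "int64"):
        return int(default_value)        # strings::safe_strto32 / safe_strto64
    if output_type_name == "bool":
        return default_value == "True"   # only the exact string "True" maps to True
    if output_type_name == "string":
        return default_value
    raise ValueError("Unsupported output_type: " + output_type_name)

# The Python layer serializes defaults with str(), so True becomes "True",
# which round-trips correctly; other spellings would silently become False.
assert typed_default("bool", str(True)) is True
assert typed_default("bool", "true") is False
```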