-
Notifications
You must be signed in to change notification settings - Fork 706
/
executor.py
88 lines (72 loc) · 3.13 KB
/
executor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic TFX BigQueryExampleGen executor."""
import json
from typing import Any, Dict, Optional
import apache_beam as beam
from google.cloud import bigquery
import tensorflow as tf
from tfx.components.example_gen import base_example_gen_executor
from tfx.extensions.google_cloud_big_query import utils
class _BigQueryConverter:
  """Helper that turns BigQuery result rows into tf.train.Example protos."""

  def __init__(self, query: str, project_id: Optional[str] = None):
    """Builds the field-name to field-type map for the given query.

    Args:
      query: the query statement whose result schema provides the type
        information.
      project_id: optional. The GCP project ID to run the query job. Default to
        the GCP project ID set by the gcloud environment on the machine.
    """
    bq_client = bigquery.Client(project=project_id)
    # Zero-row wrapper around the user query: only its result schema is read.
    schema_fields = bq_client.query(
        f'SELECT * FROM ({query}) LIMIT 0').result().schema
    self._type_map = {
        schema_field.name: schema_field.field_type
        for schema_field in schema_fields
    }

  def RowToExample(self, instance: Dict[str, Any]) -> tf.train.Example:
    """Builds a tf example from a single BigQuery result row."""
    field_types = self._type_map
    return utils.row_to_example(field_types, instance)
@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
def _BigQueryToExample(pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
                       split_pattern: str) -> beam.pvalue.PCollection:
  """Reads rows from BigQuery and converts each one to a TF example.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties. The
      '_beam_pipeline_args' entry is required; the 'custom_config' entry is
      optional and, when non-empty, is decoded as JSON.
    split_pattern: Split.pattern in Input config, a BigQuery sql string.

  Returns:
    PCollection of TF examples.
  """
  gcp_project = utils.parse_gcp_project(exec_properties['_beam_pipeline_args'])
  # Constructing the converter issues the schema lookup for `split_pattern`.
  row_converter = _BigQueryConverter(split_pattern, gcp_project)

  # A missing or empty custom config is forwarded as None.
  raw_custom_config = exec_properties.get('custom_config')
  parsed_custom_config = (
      json.loads(raw_custom_config) if raw_custom_config else None)

  bigquery_rows = pipeline | 'QueryTable' >> utils.ReadFromBigQuery(
      query=split_pattern,
      big_query_custom_config=parsed_custom_config,
  )
  return bigquery_rows | 'ToTFExample' >> beam.Map(row_converter.RowToExample)
class Executor(base_example_gen_executor.BaseExampleGenExecutor):
  """Generic TFX BigQueryExampleGen executor.

  Supplies the BigQuery-backed source transform to the base ExampleGen
  executor.
  """

  def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
    """Returns the PTransform that converts BigQuery rows to TF examples."""
    bigquery_to_example = _BigQueryToExample
    return bigquery_to_example