Skip to content

Commit

Permalink
Merge pull request #9 from rcrowe-google/Pratishtha/feature/component…
Browse files Browse the repository at this point in the history
…-spec

defining ComponentSpec class
  • Loading branch information
pratishtha-abrol committed Jun 17, 2021
2 parents 23f931d + 1108e43 commit 107acae
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 87 deletions.
Binary file added component/__pycache__/component.cpython-38.pyc
Binary file not shown.
Binary file added component/__pycache__/executor.cpython-38.pyc
Binary file not shown.
81 changes: 44 additions & 37 deletions component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of a Hello World TFX custom component.
This custom component simply reads tf.Examples from input and passes through as
output. This is meant to serve as a kind of starting point example for creating
custom components.
This component along with other custom component related code will only serve as
an example and will not be supported by TFX team.
"""TFX Schema Curation Custom Component
"""

from typing import Optional, Text
from typing import List, Optional, Text

from tfx import types
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import executor_spec
from tfx.examples.custom_components.hello_world.hello_component import executor
from tfx.types import channel_utils
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.component_spec import ExecutionParameter
from tfx.utils import json_utils

import executor

class HelloComponentSpec(types.ComponentSpec):
"""ComponentSpec for Custom TFX Hello World Component."""
class SchemaCurationSpec(types.ComponentSpec):
"""ComponentSpec for TFX Schema Curation Custom Component."""

PARAMETERS = {
# These are parameters that will be passed in the call to
# create an instance of this component.
'name': ExecutionParameter(type=Text),
'module_file': ExecutionParameter(type=str, optional=True),
'module_path': ExecutionParameter(type=str, optional=True),
'preprocessing_fn': ExecutionParameter(type=str, optional=True),
'exclude_splits': ExecutionParameter(type=str, optional=True),
}
INPUTS = {
# This will be a dictionary with input artifacts, including URIs
'input_data': ChannelParameter(type=standard_artifacts.Examples),
'statistics': ChannelParameter(type=standard_artifacts.ExampleStatistics),
'input_schema': ChannelParameter(type=standard_artifacts.Schema), # Dictionary obtained as output from SchemaGen
}
OUTPUTS = {
# This will be a dictionary which this component will populate
'output_data': ChannelParameter(type=standard_artifacts.Examples),
'output_schema': ChannelParameter(type=standard_artifacts.Schema), # Dictionary which contains the new schema
}


class HelloComponent(base_component.BaseComponent):
"""Custom TFX Hello World Component.
class SchemaCuration(base_component.BaseComponent):
"""Custom TFX Schema Curation Component.
The SchemaCuration component is used to apply user code to a schema generated by SchemaGen
in order to curate the schema based on domain knowledge.
This custom component class consists of only a constructor.
Component `outputs` contains:
- `output_schema`: Channel of type `standard_artifacts.Schema`
Current progress :
- Accepts schema, outputs the same schema
"""

SPEC_CLASS = HelloComponentSpec
SPEC_CLASS = SchemaCurationSpec
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

def __init__(self,
input_data: types.Channel = None,
output_data: types.Channel = None,
name: Optional[Text] = None):
"""Construct a HelloComponent.
statistics: types.Channel,
input_schema: types.Channel,
exclude_splits: Optional[List[Text]] = None):
"""Construct a SchemaCuration component.
Args:
input_data: A Channel of type `standard_artifacts.Examples`. This will
often contain two splits: 'train', and 'eval'.
output_data: A Channel of type `standard_artifacts.Examples`. This will
usually contain the same splits as input_data.
name: Optional unique name. Necessary if multiple Hello components are
input_schema: A dictionary that contains the schema generated by the SchemaGen component of TFX
output_schema: A dictionary that contains the schema after curation by the custom schema curation component
name: Optional unique name. Necessary if multiple custom schema curation components are
declared in the same pipeline.
"""
# output_data will contain a list of Channels for each split of the data,
# by default a 'train' split and an 'eval' split. Since HelloComponent
# passes the input data through to output, the splits in output_data will
# be the same as the splits in input_data, which were generated by the
# upstream component.
if not output_data:
output_data = channel_utils.as_channel([standard_artifacts.Examples()])
# if not output_schema:
# output_schema = channel_utils.as_channel([standard_artifacts.Examples()])

if exclude_splits is None:
exclude_splits = []
logging.info('Excluding no splits because exclude_splits is not set.')

output_schema = types.Channel(type=standard_artifacts.Schema)

spec = HelloComponentSpec(input_data=input_data,
output_data=output_data, name=name)
super(HelloComponent, self).__init__(spec=spec)
spec = SchemaCurationSpec(statistics=statistics,
input_schema=input_schema,
exclude_splits=json_utils.dumps(exclude_splits),
output_schema=output_schema)
super(SchemaCuration, self).__init__(spec=spec)
41 changes: 18 additions & 23 deletions component/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for HelloComponent."""
"""Tests for TFX Schema Curation Custom Component."""

import json

import tensorflow as tf

from tfx.examples.custom_components.hello_world.hello_component import component
from tfx.types import artifact
from tfx.types import channel_utils
from tfx.types import standard_artifacts
from tfx.types import standard_component_specs
from tfx.types import artifact_utils
import component


class HelloComponentTest(tf.test.TestCase):

def setUp(self):
super(HelloComponentTest, self).setUp()
self.name = 'HelloWorld'
class SchemaCurationTest(tf.test.TestCase):

def testConstruct(self):
input_data = standard_artifacts.Examples()
input_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
output_data = standard_artifacts.Examples()
output_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
this_component = component.HelloComponent(
input_data=channel_utils.as_channel([input_data]),
output_data=channel_utils.as_channel([output_data]),
name=u'Testing123')
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
this_component.outputs['output_data'].type_name)
artifact_collection = this_component.outputs['output_data'].get()
for artifacts in artifact_collection:
split_list = json.loads(artifacts.split_names)
self.assertEqual(artifact.DEFAULT_EXAMPLE_SPLITS.sort(),
split_list.sort())

statistics_artifact = standard_artifacts.ExampleStatistics()
statistics_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval'])
exclude_splits = []
schema_curation = component.SchemaCuration(
statistics=channel_utils.as_channel([statistics_artifact]),
input_schema=channel_utils.as_channel([standard_artifacts.Schema()]),
exclude_splits=exclude_splits,
)
self.assertEqual(
standard_artifacts.Schema.TYPE_NAME,
schema_curation.outputs['output_schema'].type_name)


if __name__ == '__main__':
tf.test.main()
67 changes: 40 additions & 27 deletions component/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of a Hello World TFX custom component.
This custom component simply passes examples through. This is meant to serve as
a kind of starting point example for creating custom components.
This component along with other custom component related code will only serve as
an example and will not be supported by TFX team.
"""Executor code for TFX Schema Curation Custom Component
"""

import json
Expand All @@ -32,9 +26,10 @@
from tfx.types import artifact_utils
from tfx.utils import io_utils

_DEFAULT_FILE_NAME = 'schema.pbtxt'

class Executor(base_executor.BaseExecutor):
"""Executor for HelloComponent."""
"""Executor for TFX Schema Curation Custom Component."""

def Do(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
Expand Down Expand Up @@ -67,22 +62,40 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
"""
self._log_startup(input_dict, output_dict, exec_properties)

input_artifact = artifact_utils.get_single_instance(
input_dict['input_data'])
output_artifact = artifact_utils.get_single_instance(
output_dict['output_data'])
output_artifact.split_names = input_artifact.split_names

split_to_instance = {}

for split in json.loads(input_artifact.split_names):
uri = artifact_utils.get_split_uri([input_artifact], split)
split_to_instance[split] = uri

for split, instance in split_to_instance.items():
input_dir = instance
output_dir = artifact_utils.get_split_uri([output_artifact], split)
for filename in fileio.listdir(input_dir):
input_uri = os.path.join(input_dir, filename)
output_uri = os.path.join(output_dir, filename)
io_utils.copy_file(src=input_uri, dst=output_uri, overwrite=True)
# Load and deserialize exclude splits from execution properties.
exclude_splits = json_utils.loads(
exec_properties.get('exclude_splits', 'null')) or []
if not isinstance(exclude_splits, list):
raise ValueError('exclude_splits in execution properties needs to be a '
'list. Got %s instead.' % type(exclude_splits))

# Setup output splits.
stats_artifact = artifact_utils.get_single_instance(
input_dict['statistics'])
stats_split_names = artifact_utils.decode_split_names(
stats_artifact.split_names)
split_names = [
split for split in stats_split_names if split not in exclude_splits
]

schema = io_utils.SchemaReader().read(
io_utils.get_only_uri_in_dir(
artifact_utils.get_single_uri(
input_dict['schema'])))

for split in artifact_utils.decode_split_names(stats_artifact.split_names):
if split in exclude_splits:
continue
logging.info(
'Curating schema against the computed statistics for '
'split %s.', split)

output_schema = schema

output_uri = os.path.join(
artifact_utils.get_single_uri(
output_dict['schema']),
_DEFAULT_FILE_NAME)
io_utils.write_pbtxt_file(output_uri, output_schema)
logging.info('Schema written to %s.', output_uri)

0 comments on commit 107acae

Please sign in to comment.