Skip to content

Commit

Permalink
Introduces experimental flag to switch tfxios to byte size-based batc…
Browse files Browse the repository at this point in the history
…hing.

The byte size-based batching is disabled by default, but will be enabled after experiments.

PiperOrigin-RevId: 511526479
  • Loading branch information
iindyk authored and tfx-copybara committed Feb 22, 2023
1 parent fc9461b commit ce2e12b
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 70 deletions.
106 changes: 100 additions & 6 deletions tfx_bsl/coders/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,134 @@
"""Utilities for batching."""

import inspect
from typing import Optional
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar
from absl import flags

import apache_beam as beam
from tfx_bsl.telemetry import util as telemetry_util

# Beam might grow the batch size too large for Arrow BinaryArray / ListArray
# to hold the contents (e.g. if the sum of the length of a string feature in
# a batch exceeds 2GB). Before the decoder can produce LargeBinaryArray /
# LargeListArray, we have to cap the batch size.
_BATCH_SIZE_CAP = 1000

# Experimental and will be removed in the future.
# Controls whether to delegate batch size tuning to `beam.BatchElements` or to
# batch records based on target size of the batch in bytes.
# TODO(b/266803710): Switch to byte size batching by default and clean this up.
_USE_BYTE_SIZE_BATCHING = flags.DEFINE_bool(
"tfxio_use_byte_size_batching",
False,
(
"By default input TFXIO sources will delegate tuning of the batch size "
"of input data to Beam. If this flag is set to True, the sources will "
"batch elements based on the target batch size in bytes."
),
)
# Batch size is determined by the target size in bytes, but not larger than the
# cap.
# Note that this upper bound in byte size applies to the sum of encoded records
# rather than the produced decoded batch itself. In most cases, however, the
# size of the latter is bounded above by the size of the former. Exception to
# this rule is a case when there are many empty features in the encoded
# examples, but even then the difference is not significant and it is likely
# that the actual size cap will be applied first.
_TARGET_BATCH_BYTES_SIZE = 104_857_600 # 100MB
_BATCH_SIZE_CAP_WITH_BYTE_TARGET = 10000

def GetBatchElementsKwargs(batch_size: Optional[int]):

def GetBatchElementsKwargs(
batch_size: Optional[int], element_size_fn: Callable[[Any], int] = len
) -> Dict[str, Any]:
"""Returns the kwargs to pass to beam.BatchElements()."""
if batch_size is not None:
return {
"min_batch_size": batch_size,
"max_batch_size": batch_size,
}
if _USE_BYTE_SIZE_BATCHING.value:
min_element_size = int(
math.ceil(_TARGET_BATCH_BYTES_SIZE / _BATCH_SIZE_CAP_WITH_BYTE_TARGET)
)
return {
"min_batch_size": _TARGET_BATCH_BYTES_SIZE,
"max_batch_size": _TARGET_BATCH_BYTES_SIZE,
"element_size_fn": lambda e: max(element_size_fn(e), min_element_size),
}
# Allow `BatchElements` to tune the values with the given parameters.
# We fix the tuning parameters here to prevent Beam changes from immediately
# affecting all dependencies.
result = {
"min_batch_size": 1,
"max_batch_size": _BATCH_SIZE_CAP,
"target_batch_overhead": 0.05,
"target_batch_duration_secs": 1,
"variance": 0.25,
}
# We fix the parameters here to prevent Beam changes from immediately
# affecting all dependencies.
# TODO(b/266803710): Clean this up after deciding on optimal batch_size
# selection logic.
batch_elements_signature = inspect.signature(beam.BatchElements)
if (
"target_batch_duration_secs_including_fixed_cost"
in batch_elements_signature.parameters
):
result["target_batch_duration_secs_including_fixed_cost"] = 1
return result


def _MakeAndIncrementBatchingMetrics(
unused_element: Any,
batch_size: Optional[int],
telemetry_descriptors: Optional[Sequence[str]],
) -> None:
"""Increments metrics relevant to batching."""
namespace = telemetry_util.MakeTfxNamespace(
telemetry_descriptors or ["Unknown"]
)
beam.metrics.Metrics.counter(namespace, "tfxio_use_byte_size_batching").inc(
int(_USE_BYTE_SIZE_BATCHING.value)
)
beam.metrics.Metrics.counter(namespace, "desired_batch_size").inc(
batch_size or 0
)


T = TypeVar("T")


@beam.ptransform_fn
@beam.typehints.with_input_types(T)
@beam.typehints.with_output_types(List[T])
def BatchRecords(
records: beam.PCollection,
batch_size: Optional[int],
telemetry_descriptors: Optional[Sequence[str]],
record_size_fn: Callable[[T], int] = len,
) -> beam.PCollection:
"""Batches collection of records tuning the batch size if not provided.
Args:
records: A PCollection of records to batch.
batch_size: Desired batch size. If None, will be tuned for optimal
performance.
telemetry_descriptors: Descriptors to use for batching metrics.
record_size_fn: Function used to determine size of each record in bytes.
Only used if byte size-based batching is enabled. Defaults to `len`
function suitable for bytes records.
Returns:
A PCollection of batched records.
"""
_ = (
records.pipeline
| "CreateSole" >> beam.Create([None])
| "IncrementMetrics"
>> beam.Map(
_MakeAndIncrementBatchingMetrics,
batch_size=batch_size,
telemetry_descriptors=telemetry_descriptors,
)
)
return records | "BatchElements" >> beam.BatchElements(
**GetBatchElementsKwargs(batch_size, record_size_fn)
)
164 changes: 148 additions & 16 deletions tfx_bsl/coders/batch_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,164 @@
# limitations under the License.
"""Tests for tfx_bsl.coders.batch_util."""

from tfx_bsl.coders import batch_util
from absl.testing import absltest
from absl.testing import flagsaver

import apache_beam as beam
from apache_beam.testing import util as beam_testing_util

class BatchUtilTest(absltest.TestCase):
from tfx_bsl.coders import batch_util
from absl.testing import absltest
from absl.testing import parameterized

def testGetBatchElementsKwargs(self):
kwargs = batch_util.GetBatchElementsKwargs(batch_size=None)
target_batch_duration_secs_including_fixed_cost = kwargs.pop(
"target_batch_duration_secs_including_fixed_cost", None
)
self.assertIn(target_batch_duration_secs_including_fixed_cost, {1, None})
self.assertDictEqual(
kwargs,
{
_BATCH_RECORDS_TEST_CASES = (
dict(
testcase_name="fixed_batch_size",
batch_size=5000,
tfxio_use_byte_size_batching=False,
expected_kwargs={"max_batch_size": 5000, "min_batch_size": 5000},
),
dict(
testcase_name="fixed_batch_size_byte_size_batching",
batch_size=5000,
tfxio_use_byte_size_batching=True,
expected_kwargs={"max_batch_size": 5000, "min_batch_size": 5000},
),
dict(
testcase_name="batch_size_none",
batch_size=None,
tfxio_use_byte_size_batching=False,
expected_kwargs={
"min_batch_size": 1,
"max_batch_size": 1000,
"target_batch_overhead": 0.05,
"target_batch_duration_secs": 1,
"variance": 0.25,
},
),
dict(
testcase_name="byte_size_batching",
batch_size=None,
tfxio_use_byte_size_batching=True,
expected_kwargs={
"min_batch_size": 104_857_600,
"max_batch_size": 104_857_600,
"element_size_fn": "dummy",
},
expected_element_contributions={
b"dummy": 10486, # Minimal contribution.
b"dummy" * 10000: 50000,
},
),
dict(
testcase_name="byte_size_batching_with_element_size_fn",
batch_size=None,
tfxio_use_byte_size_batching=True,
expected_kwargs={
"min_batch_size": 104_857_600,
"max_batch_size": 104_857_600,
"element_size_fn": "dummy",
},
element_size_fn=lambda kv: len(kv[0] or b"") + len(kv[1]),
expected_element_contributions={
(None, b"dummy"): 10486, # Minimal contribution.
(b"asd", b"dummy" * 10000): 50003,
},
),
)


class BatchUtilTest(parameterized.TestCase):

@parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES)
def testGetBatchElementsKwargs(
self,
batch_size,
tfxio_use_byte_size_batching,
expected_kwargs,
element_size_fn=len,
expected_element_contributions=None,
):
with flagsaver.flagsaver(
tfxio_use_byte_size_batching=tfxio_use_byte_size_batching
):
kwargs = batch_util.GetBatchElementsKwargs(
batch_size, element_size_fn=element_size_fn
)
# This parameter may not be present in some Beam versions that we support.
target_batch_duration_secs_including_fixed_cost = kwargs.pop(
"target_batch_duration_secs_including_fixed_cost", None
)
self.assertIn(target_batch_duration_secs_including_fixed_cost, {1, None})
if expected_kwargs.pop("element_size_fn", None) is not None:
self.assertIn("element_size_fn", kwargs)
element_size_fn = kwargs.pop("element_size_fn")
for (
element,
expected_contribution,
) in expected_element_contributions.items():
self.assertEqual(
element_size_fn(element),
expected_contribution,
msg=f"Unexpected contribution of element {element}",
)
self.assertDictEqual(kwargs, expected_kwargs)

@parameterized.named_parameters(*_BATCH_RECORDS_TEST_CASES)
def testBatchRecords(
self,
batch_size,
tfxio_use_byte_size_batching,
expected_kwargs,
element_size_fn=len,
expected_element_contributions=None,
):
del expected_kwargs
telemetry_descriptors = ["TestComponent"]
input_records = (
[b"asd", b"asds", b"123", b"gdgd" * 1000]
if expected_element_contributions is None
else expected_element_contributions.keys()
)
kwargs = batch_util.GetBatchElementsKwargs(batch_size=5000)
self.assertDictEqual(
kwargs, {"max_batch_size": 5000, "min_batch_size": 5000}
)

def AssertFn(batched_records):
# We can't validate the actual sizes since they depend on test
# environment.
self.assertNotEmpty(batched_records)
for batch in batched_records:
self.assertIsInstance(batch, list)
self.assertNotEmpty(batch)

with flagsaver.flagsaver(
tfxio_use_byte_size_batching=tfxio_use_byte_size_batching
):
p = beam.Pipeline()
batched_records_pcoll = (
p
| beam.Create(input_records)
| batch_util.BatchRecords(
batch_size, telemetry_descriptors, record_size_fn=element_size_fn
)
)
beam_testing_util.assert_that(batched_records_pcoll, AssertFn)
pipeline_result = p.run()
pipeline_result.wait_until_finish()
all_metrics = pipeline_result.metrics()
maintained_metrics = all_metrics.query(
beam.metrics.metric.MetricsFilter().with_namespace(
"tfx." + ".".join(telemetry_descriptors)
)
)
self.assertIsNotNone(maintained_metrics)
counters = maintained_metrics[beam.metrics.metric.MetricResults.COUNTERS]
self.assertLen(counters, 2)
expected_counters = {
"tfxio_use_byte_size_batching": int(tfxio_use_byte_size_batching),
"desired_batch_size": batch_size or 0,
}
for counter in counters:
self.assertEqual(
counter.result, expected_counters[counter.key.metric.name]
)


if __name__ == "__main__":
Expand Down
30 changes: 20 additions & 10 deletions tfx_bsl/coders/csv_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,26 @@ def CSVToRecordBatch(lines: beam.pvalue.PCollection,
secondary_delimiter=secondary_delimiter)))

# Do second pass to generate the RecordBatches.
return (csv_lines_and_raw_records
| "BatchCSVLines" >> beam.BatchElements(
**batch_util.GetBatchElementsKwargs(desired_batch_size))
| "BatchedCSVRowsToArrow" >> beam.ParDo(
BatchedCSVRowsToRecordBatch(
skip_blank_lines=skip_blank_lines,
multivalent_columns=multivalent_columns,
secondary_delimiter=secondary_delimiter,
raw_record_column_name=raw_record_column_name),
column_infos))
return (
csv_lines_and_raw_records
| "BatchCSVLines"
>> batch_util.BatchRecords(
desired_batch_size,
telemetry_descriptors=["CSVToRecordBatch"],
# The elements are tuples of parsed and unparsed CSVlines.
record_size_fn=lambda kv: len(kv[1]) << 1,
)
| "BatchedCSVRowsToArrow"
>> beam.ParDo(
BatchedCSVRowsToRecordBatch(
skip_blank_lines=skip_blank_lines,
multivalent_columns=multivalent_columns,
secondary_delimiter=secondary_delimiter,
raw_record_column_name=raw_record_column_name,
),
column_infos,
)
)


@beam.typehints.with_input_types(CSVLine)
Expand Down
Loading

0 comments on commit ce2e12b

Please sign in to comment.