Commit

[Data] Skip schema call in to_tf if tf.TypeSpec is provided (#42917)

This PR allows skipping the `Dataset.schema()` call in `Dataset.to_tf()`. `Dataset.schema()` relies on `limit` to stop execution early, and sometimes that stop is not triggered in time, so many tasks get executed and the resulting memory pressure causes spilling. In addition, `schema()` sometimes returns `None` (when limit push-down does not apply), which breaks the downstream logic in `to_tf`, all of which depends on `schema()`.

In this PR:
* Introduce two optional parameters in `to_tf`: `feature_type_spec` and `label_type_spec` (both default to `None`). Users can set the `tf.TypeSpec` values explicitly so that the `Dataset.schema()` call is skipped, as sketched below.

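For example, a minimal sketch of passing explicit specs (toy column names and dtypes; assumes TensorFlow is installed):

```python
import ray
import tensorflow as tf

ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

# One tf.TypeSpec per feature column (keyed by column name), and a single
# tf.TypeSpec for the single label column. The shapes/dtypes here are
# assumptions for this toy data; the leading None is the batch dimension.
feature_type_spec = {
    "spam": tf.TensorSpec(shape=(None,), dtype=tf.int64),
    "ham": tf.TensorSpec(shape=(None,), dtype=tf.int64),
}
label_type_spec = tf.TensorSpec(shape=(None,), dtype=tf.int64)

# With both specs provided, to_tf() no longer needs to call Dataset.schema().
tf_ds = ds.to_tf(
    feature_columns=["spam", "ham"],
    label_columns="eggs",
    feature_type_spec=feature_type_spec,
    label_type_spec=label_type_spec,
)
```
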
Signed-off-by: Cheng Su <scnju13@gmail.com>
c21 committed Feb 8, 2024
1 parent ffc9101 commit c90b476
Showing 3 changed files with 49 additions and 10 deletions.
12 changes: 12 additions & 0 deletions python/ray/data/dataset.py
@@ -4004,6 +4004,8 @@ def to_tf(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
# Deprecated
prefetch_blocks: int = 0,
) -> "tf.data.Dataset":
@@ -4088,6 +4090,14 @@ def to_tf(
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
Returns:
A `TensorFlow Dataset`_ that yields inputs and targets.
@@ -4107,6 +4117,8 @@ def to_tf(
batch_size=batch_size,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
feature_type_spec=feature_type_spec,
label_type_spec=label_type_spec,
)

@ConsumptionAPI(pattern="Time complexity:")
28 changes: 18 additions & 10 deletions python/ray/data/iterator.py
@@ -676,6 +676,8 @@ def to_tf(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
# Deprecated.
prefetch_blocks: int = 0,
) -> "tf.data.Dataset":
@@ -761,6 +763,14 @@ def to_tf(
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
Returns:
A ``tf.data.Dataset`` that yields inputs and targets.
@@ -776,9 +786,6 @@ def to_tf(
except ImportError:
raise ValueError("tensorflow must be installed!")

schema = self.schema()
valid_columns = schema.names

def validate_column(column: str) -> None:
if column not in valid_columns:
raise ValueError(
Expand All @@ -794,9 +801,6 @@ def validate_columns(columns: Union[str, List]) -> None:
else:
validate_column(columns)

validate_columns(feature_columns)
validate_columns(label_columns)

def convert_batch_to_tensors(
batch: Dict[str, np.ndarray],
*,
@@ -830,12 +834,16 @@ def generator():
)
yield features, labels

feature_type_spec = get_type_spec(schema, columns=feature_columns)
label_type_spec = get_type_spec(schema, columns=label_columns)
output_signature = (feature_type_spec, label_type_spec)
if feature_type_spec is None or label_type_spec is None:
schema = self.schema()
valid_columns = schema.names
validate_columns(feature_columns)
validate_columns(label_columns)
feature_type_spec = get_type_spec(schema, columns=feature_columns)
label_type_spec = get_type_spec(schema, columns=label_columns)

dataset = tf.data.Dataset.from_generator(
generator, output_signature=output_signature
generator, output_signature=(feature_type_spec, label_type_spec)
)

options = tf.data.Options()
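The hunk above feeds the explicit (or inferred) specs into `tf.data.Dataset.from_generator` as the `output_signature`. A minimal standalone sketch of that mechanism, with toy data and assumed dtypes:

```python
import numpy as np
import tensorflow as tf

def generator():
    # Toy (features, labels) batches; dtypes chosen to match the specs below.
    features = {
        "spam": np.array([0, 1], dtype=np.int64),
        "ham": np.array([2, 3], dtype=np.int64),
    }
    labels = np.array([0, 1], dtype=np.int64)
    yield features, labels

feature_type_spec = {
    "spam": tf.TensorSpec(shape=(None,), dtype=tf.int64),
    "ham": tf.TensorSpec(shape=(None,), dtype=tf.int64),
}
label_type_spec = tf.TensorSpec(shape=(None,), dtype=tf.int64)

dataset = tf.data.Dataset.from_generator(
    generator, output_signature=(feature_type_spec, label_type_spec)
)

for features, labels in dataset:
    print(features["spam"].numpy(), labels.numpy())
```
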
19 changes: 19 additions & 0 deletions python/ray/data/tests/test_tf.py
@@ -31,6 +31,25 @@ def test_element_spec_type(self):
assert isinstance(feature_spec, tf.TypeSpec)
assert isinstance(label_spec, tf.TypeSpec)

def test_element_spec_user_provided(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

dataset1 = ds.to_tf(feature_columns=["spam", "ham"], label_columns="eggs")
feature_spec, label_spec = dataset1.element_spec
dataset2 = ds.to_tf(
feature_columns=["spam", "ham"],
label_columns="eggs",
feature_type_spec=feature_spec,
label_type_spec=label_spec,
)
feature_output_spec, label_output_spec = dataset2.element_spec
assert isinstance(label_output_spec, tf.TypeSpec)
assert isinstance(feature_output_spec, dict)
assert feature_output_spec.keys() == {"spam", "ham"}
assert all(
isinstance(value, tf.TypeSpec) for value in feature_output_spec.values()
)

def test_element_spec_type_with_multiple_columns(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

