Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Skip schema call in to_tf if tf.TypeSpec is provided #42917

Merged
merged 5 commits into from Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/ray/data/dataset.py
Expand Up @@ -4004,6 +4004,8 @@ def to_tf(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
# Deprecated
prefetch_blocks: int = 0,
) -> "tf.data.Dataset":
Expand Down Expand Up @@ -4088,6 +4090,14 @@ def to_tf(
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.

Returns:
A `TensorFlow Dataset`_ that yields inputs and targets.
Expand All @@ -4107,6 +4117,8 @@ def to_tf(
batch_size=batch_size,
local_shuffle_buffer_size=local_shuffle_buffer_size,
local_shuffle_seed=local_shuffle_seed,
feature_type_spec=feature_type_spec,
label_type_spec=label_type_spec,
)

@ConsumptionAPI(pattern="Time complexity:")
Expand Down
28 changes: 18 additions & 10 deletions python/ray/data/iterator.py
Expand Up @@ -676,6 +676,8 @@ def to_tf(
drop_last: bool = False,
local_shuffle_buffer_size: Optional[int] = None,
local_shuffle_seed: Optional[int] = None,
feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
# Deprecated.
prefetch_blocks: int = 0,
) -> "tf.data.Dataset":
Expand Down Expand Up @@ -761,6 +763,14 @@ def to_tf(
therefore ``batch_size`` must also be specified when using local
shuffling.
local_shuffle_seed: The seed to use for the local random shuffle.
feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.
label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
only one column, specify a `tf.TypeSpec`. If there are multiple columns,
specify a ``dict`` that maps column names to their `tf.TypeSpec`.
Default is `None` to automatically infer the type of each column.

Returns:
A ``tf.data.Dataset`` that yields inputs and targets.
Expand All @@ -776,9 +786,6 @@ def to_tf(
except ImportError:
raise ValueError("tensorflow must be installed!")

schema = self.schema()
valid_columns = schema.names

def validate_column(column: str) -> None:
if column not in valid_columns:
raise ValueError(
Expand All @@ -794,9 +801,6 @@ def validate_columns(columns: Union[str, List]) -> None:
else:
validate_column(columns)

validate_columns(feature_columns)
validate_columns(label_columns)

def convert_batch_to_tensors(
batch: Dict[str, np.ndarray],
*,
Expand Down Expand Up @@ -830,12 +834,16 @@ def generator():
)
yield features, labels

feature_type_spec = get_type_spec(schema, columns=feature_columns)
label_type_spec = get_type_spec(schema, columns=label_columns)
output_signature = (feature_type_spec, label_type_spec)
if feature_type_spec is None or label_type_spec is None:
schema = self.schema()
valid_columns = schema.names
validate_columns(feature_columns)
validate_columns(label_columns)
feature_type_spec = get_type_spec(schema, columns=feature_columns)
label_type_spec = get_type_spec(schema, columns=label_columns)

dataset = tf.data.Dataset.from_generator(
generator, output_signature=output_signature
generator, output_signature=(feature_type_spec, label_type_spec)
)

options = tf.data.Options()
Expand Down
19 changes: 19 additions & 0 deletions python/ray/data/tests/test_tf.py
Expand Up @@ -31,6 +31,25 @@ def test_element_spec_type(self):
assert isinstance(feature_spec, tf.TypeSpec)
assert isinstance(label_spec, tf.TypeSpec)

def test_element_spec_user_provided(self):
    """Round-trip check: specs inferred by one ``to_tf`` call can be passed
    back in via ``feature_type_spec``/``label_type_spec`` and produce an
    equivalent element spec on the resulting ``tf.data.Dataset``."""
    ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

    # First conversion: type specs are inferred from the dataset schema.
    inferred_feature_spec, inferred_label_spec = ds.to_tf(
        feature_columns=["spam", "ham"], label_columns="eggs"
    ).element_spec

    # Second conversion: feed the inferred specs back in explicitly.
    tf_dataset = ds.to_tf(
        feature_columns=["spam", "ham"],
        label_columns="eggs",
        feature_type_spec=inferred_feature_spec,
        label_type_spec=inferred_label_spec,
    )
    out_feature_spec, out_label_spec = tf_dataset.element_spec

    assert isinstance(out_label_spec, tf.TypeSpec)
    assert isinstance(out_feature_spec, dict)
    assert out_feature_spec.keys() == {"spam", "ham"}
    assert all(
        isinstance(spec, tf.TypeSpec) for spec in out_feature_spec.values()
    )

def test_element_spec_type_with_multiple_columns(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

Expand Down