diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py
index 05e5a7e24e5f96..b2f6cd2fa37534 100644
--- a/python/ray/data/iterator.py
+++ b/python/ray/data/iterator.py
@@ -772,32 +772,45 @@ def to_tf(
             A ``tf.data.Dataset`` that yields inputs and targets.
         """  # noqa: E501
 
-        from ray.air._internal.tensorflow_utils import convert_ndarray_to_tf_tensor
+        from ray.air._internal.tensorflow_utils import (
+            convert_ndarray_to_tf_tensor,
+            get_type_spec,
+        )
 
         try:
             import tensorflow as tf
         except ImportError:
             raise ValueError("tensorflow must be installed!")
 
+        def validate_column(column: str) -> None:
+            if column not in valid_columns:
+                raise ValueError(
+                    f"You specified '{column}' in `feature_columns` or "
+                    f"`label_columns`, but there's no column named '{column}' in the "
+                    f"dataset. Valid column names are: {valid_columns}."
+                )
+
+        def validate_columns(columns: Union[str, List]) -> None:
+            if isinstance(columns, list):
+                for column in columns:
+                    validate_column(column)
+            else:
+                validate_column(columns)
+
         def convert_batch_to_tensors(
             batch: Dict[str, np.ndarray],
             *,
             columns: Union[str, List[str]],
-            type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]] = None,
+            type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]],
         ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
             if isinstance(columns, str):
                 return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
-            else:
-                tensors = {}
-                for column in columns:
-                    if type_spec is not None:
-                        column_type_spec = type_spec[column]
-                    else:
-                        column_type_spec = None
-                    tensors[column] = convert_ndarray_to_tf_tensor(
-                        batch[column], type_spec=column_type_spec
-                    )
-                return tensors
+            return {
+                column: convert_ndarray_to_tf_tensor(
+                    batch[column], type_spec=type_spec[column]
+                )
+                for column in columns
+            }
 
         def generator():
             for batch in self.iter_batches(
@@ -817,13 +830,16 @@ def generator():
                 )
                 yield features, labels
 
-        if feature_type_spec is not None and label_type_spec is not None:
-            output_signature = (feature_type_spec, label_type_spec)
-        else:
-            output_signature = None
+        if feature_type_spec is None or label_type_spec is None:
+            schema = self.schema()
+            valid_columns = schema.names
+            validate_columns(feature_columns)
+            validate_columns(label_columns)
+            feature_type_spec = get_type_spec(schema, columns=feature_columns)
+            label_type_spec = get_type_spec(schema, columns=label_columns)
 
         dataset = tf.data.Dataset.from_generator(
-            generator, output_signature=output_signature
+            generator, output_signature=(feature_type_spec, label_type_spec)
         )
 
         options = tf.data.Options()