Skip to content

Commit

Permalink
Change to call if type_spec not provided
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng Su <scnju13@gmail.com>
  • Loading branch information
c21 committed Feb 2, 2024
1 parent 99bb9e6 commit ae28239
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions python/ray/data/iterator.py
Expand Up @@ -772,32 +772,45 @@ def to_tf(
A ``tf.data.Dataset`` that yields inputs and targets.
""" # noqa: E501

from ray.air._internal.tensorflow_utils import convert_ndarray_to_tf_tensor
from ray.air._internal.tensorflow_utils import (
convert_ndarray_to_tf_tensor,
get_type_spec,
)

try:
import tensorflow as tf
except ImportError:
raise ValueError("tensorflow must be installed!")

def validate_column(column: str) -> None:
if column not in valid_columns:
raise ValueError(
f"You specified '{column}' in `feature_columns` or "
f"`label_columns`, but there's no column named '{column}' in the "
f"dataset. Valid column names are: {valid_columns}."
)

def validate_columns(columns: Union[str, List]) -> None:
if isinstance(columns, list):
for column in columns:
validate_column(column)
else:
validate_column(columns)

def convert_batch_to_tensors(
batch: Dict[str, np.ndarray],
*,
columns: Union[str, List[str]],
type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]] = None,
type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]],
) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
if isinstance(columns, str):
return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
else:
tensors = {}
for column in columns:
if type_spec is not None:
column_type_spec = type_spec[column]
else:
column_type_spec = None
tensors[column] = convert_ndarray_to_tf_tensor(
batch[column], type_spec=column_type_spec
)
return tensors
return {
column: convert_ndarray_to_tf_tensor(
batch[column], type_spec=type_spec[column]
)
for column in columns
}

def generator():
for batch in self.iter_batches(
Expand All @@ -817,13 +830,16 @@ def generator():
)
yield features, labels

if feature_type_spec is not None and label_type_spec is not None:
output_signature = (feature_type_spec, label_type_spec)
else:
output_signature = None
if feature_type_spec is None or label_type_spec is None:
schema = self.schema()
valid_columns = schema.names
validate_columns(feature_columns)
validate_columns(label_columns)
feature_type_spec = get_type_spec(schema, columns=feature_columns)
label_type_spec = get_type_spec(schema, columns=label_columns)

dataset = tf.data.Dataset.from_generator(
generator, output_signature=output_signature
generator, output_signature=(feature_type_spec, label_type_spec)
)

options = tf.data.Options()
Expand Down

0 comments on commit ae28239

Please sign in to comment.