In [23]:
import tensorflow as tf
from tensorflow.data import TFRecordDataset, AUTOTUNE
# import tf.data.TFRecordDataset
import jax
from typing import Optional
from waymax.config import DatasetConfig, DataFormat
import jax.numpy as jnp
import functools


In [24]:
# One WOMD training shard (uncompressed tf.Example TFRecord), relative to the notebook's CWD.
path = "./uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000"
# Max roadgraph ("rg") points; forwarded to womd_utils.get_features_description
# when parsing records.
MAX_NUM_RG_POINTS = 20000

myconfig = DatasetConfig(
    path=path,  # reuse `path` so the config and the file list cannot drift apart
    max_num_rg_points=MAX_NUM_RG_POINTS,
    data_format=DataFormat.TFRECORD,
)

In [25]:
from waymax.dataloader import womd_utils
def preprocess_serialized_womd_data(
    serialized: bytes, config: DatasetConfig
) -> dict[str, tf.Tensor]:
  """Deserializes WOMD tf.Example bytes and runs example-level preprocessing.

  Args:
    serialized: Serialized tf.Example record(s). NOTE(review): when used via
      `Dataset.map` over a TFRecordDataset this arrives as a scalar string
      tensor rather than Python `bytes`.
    config: Dataset configuration. Supplies the feature-spec sizes
      (`include_sdc_paths`, `max_num_rg_points`, `num_paths`,
      `num_points_per_path`) and the preprocessing options
      (`aggregate_timesteps`, `max_num_objects`).

  Returns:
    Dict of parsed tensors keyed by feature name, after
    `preprocess_womd_example` has been applied.
  """
  # Feature-spec sizing comes straight from the config.
  spec_kwargs = dict(
      include_sdc_paths=config.include_sdc_paths,
      max_num_rg_points=config.max_num_rg_points,
      num_paths=config.num_paths,
      num_points_per_path=config.num_points_per_path,
  )
  feature_spec = womd_utils.get_features_description(**spec_kwargs)

  parsed = tf.io.parse_example(serialized, feature_spec)
  return preprocess_womd_example(
      parsed,
      config.aggregate_timesteps,
      max_num_objects=config.max_num_objects,
  )


def preprocess_womd_example(
    example: dict[str, tf.Tensor],
    aggregate_timesteps: bool,
    max_num_objects: Optional[int] = None,
) -> dict[str, tf.Tensor]:
  """Optionally time-aggregates and object-truncates a parsed WOMD example.

  Args:
    example: Parsed feature dict, keyed by feature name.
    aggregate_timesteps: If True, merge the per-timestep tensors with
      `womd_utils.aggregate_time_tensors` and wrap 'state/all/bbox_yaw'
      into [-pi, pi).
    max_num_objects: If not None, keep only the first `max_num_objects`
      entries along the leading axis of every key starting with 'state/';
      all other keys are passed through untouched.

  Returns:
    The processed dict. When `aggregate_timesteps` is False and
    `max_num_objects` is None, `example` itself is returned (not a copy).
  """
  if not aggregate_timesteps:
    processed = example
  else:
    processed = womd_utils.aggregate_time_tensors(example)
    # Shift by pi, floor-mod by 2*pi, shift back: maps any angle into [-pi, pi).
    two_pi = 2 * jnp.pi
    yaw = processed['state/all/bbox_yaw']
    processed['state/all/bbox_yaw'] = (yaw + jnp.pi) % two_pi - jnp.pi

  if max_num_objects is None:
    return processed

  # TODO check sdc included if it is needed.
  # NOTE(review): plain prefix truncation could drop the SDC if it is not
  # among the first `max_num_objects` objects — confirm.
  truncated = {}
  for key, value in processed.items():
    truncated[key] = value[:max_num_objects] if key.startswith('state/') else value
  return truncated



In [26]:
files_to_load = [path]

# Build a dataset of file paths, then keep only this JAX process's shard so
# that multi-process runs read disjoint subsets of the files.
files = tf.data.Dataset.from_tensor_slices(files_to_load).shard(
    num_shards=jax.process_count(),
    index=jax.process_index(),
)



In [27]:
# Stream raw serialized records from every file in this process's shard.
# deterministic=True keeps output order reproducible despite parallel reads.
raw_records = files.interleave(
    TFRecordDataset, num_parallel_calls=AUTOTUNE, deterministic=True
)

# Bind the dataset config so the map function takes a single serialized record.
preprocess_fn = functools.partial(preprocess_serialized_womd_data, config=myconfig)

# Parse and preprocess each record into a dict of tensors keyed by feature name.
data = raw_records.map(
    preprocess_fn, num_parallel_calls=AUTOTUNE, deterministic=True
)


In [31]:
from waymax.dataloader import womd_factories

# Convert the first preprocessed example into a Waymax SimulatorState.
# `as_numpy_iterator()` hands the JAX-based factory NumPy arrays instead of eager
# TF tensors. `next(...)` raises StopIteration on an empty dataset rather than
# silently leaving `s` undefined, as the old one-iteration `for` loop would.
first_example = next(data.take(1).as_numpy_iterator())
s = womd_factories.simulator_state_from_womd_dict(
    first_example, include_sdc_paths=myconfig.include_sdc_paths
)


In [29]:
# Show the config fields that drive parsing and preprocessing, labeled by name
# (the previous bare print made it impossible to tell which value was which).
{
    name: getattr(myconfig, name)
    for name in (
        "include_sdc_paths",
        "max_num_rg_points",
        "num_paths",
        "num_points_per_path",
        "aggregate_timesteps",
        "max_num_objects",
    )
}
