In [2]:
# %env CUDA_VISIBLE_DEVICES=1
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env TF_GPU_ALLOCATOR=cuda_malloc_async
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=.99
%env JAX_PLATFORMS=cpu

# GIT_TRACE=1 pip install pytest git+https://github.com/Physical-Intelligence/openpi.git#subdirectory=packages/openpi-client git+https://github.com/Physical-Intelligence/openpi.git -v

import jax
import flax.nnx

import openpi.models.pi0
import openpi.models.pi0_config
import openpi.models.model


# Flip to True to load openpi's parameters into the robotodo model (see the TODO
# below); while False, both models only share the same init seed.
should_copy_weights = False


# Reference openpi config. The "dummy" PaliGemma / action-expert variants keep the
# models tiny so the comparison is feasible on CPU (JAX_PLATFORMS=cpu above).
openpi0_config = openpi.models.pi0_config.Pi0Config(
    action_dim=4, 
    action_horizon=2, 
    max_token_len=8, 
    dtype="bfloat16", 
    paligemma_variant="dummy", 
    action_expert_variant="dummy",
    # pi05=False,  # alternative: compare the non-pi05 (pi0) variant instead
    pi05=True,
)


def make_init_rngs(seed: int = 0):
    """Return a fresh `flax.nnx.Rngs` seeded with `seed` (default 0).

    `flax.nnx.Rngs` is stateful: each draw advances its internal counter, so one
    instance cannot be reused across model constructions without changing the
    random stream. Calling this once per model gives every model an
    identically-seeded, untouched stream.

    Args:
        seed: integer seed for `jax.random.key`.

    Returns:
        A new `flax.nnx.Rngs` instance.
    """
    return flax.nnx.Rngs(jax.random.key(seed))

def make_rng(seed: int = 0):
    """Return a fresh JAX PRNG key for `seed` (default 0).

    Both implementations receive keys from this helper so that they consume the
    same randomness in the sampling / loss comparisons.

    Args:
        seed: integer seed for `jax.random.key`.

    Returns:
        A typed JAX PRNG key.
    """
    return jax.random.key(seed)


# Reference model: openpi's own Pi0, initialized from a fresh seed-0 Rngs.
openpi0 = openpi.models.pi0.Pi0(
    openpi0_config, 
    rngs=make_init_rngs(),
)

# openpi-provided fake inputs. The first argument to preprocess_observation is
# None (presumably the augmentation rng, unused with train=False — TODO confirm).
observation_openpi0 = openpi0_config.fake_obs()
observation_openpi0 = openpi.models.model.preprocess_observation(None, observation_openpi0, train=False)
action_openpi0 = openpi0_config.fake_act()

# Horizon is the second-to-last axis of the fake action batch.
action_horizon = action_openpi0.shape[-2]
# Number of denoising steps passed to both implementations' sample_actions.
denoising_num_steps = 10






from robotodo.algos.pi0.nn.pi0 import Pi0




# Build the robotodo port with the same hyperparameters as the openpi reference.
# The seed-0 Rngs is created first, exactly as the keyword-argument order did before.
_pi0_init_rngs = make_init_rngs()
# Image geometry comes from openpi's fake "base_0_rgb" camera (trailing H, W, C axes).
_base_height, _base_width, _base_channels = observation_openpi0.images["base_0_rgb"].shape[-3:]
_pi0_config = Pi0.Config(
    image_shape={
        "height": _base_height,
        "width": _base_width,
        "channels": _base_channels,
    },
    numeric_state_size=observation_openpi0.state.shape[-1],
    dtype=openpi0_config.dtype,
    variant="pi05" if openpi0.pi05 else "pi0",
    paligemma_variant=openpi0_config.paligemma_variant,
    action_expert_variant=openpi0_config.action_expert_variant,
)
pi0 = Pi0(rngs=_pi0_init_rngs, config=_pi0_config)

# TODO
if should_copy_weights:
    # Keep robotodo's module structure (graphdef) but load openpi's parameter state;
    # this only works if the two state trees line up.
    pi0: Pi0 = flax.nnx.merge(flax.nnx.graphdef(pi0), flax.nnx.state(openpi0))


# Translate openpi's fake observation into robotodo's format: an empty observation
# with one image slot per openpi camera, filled in openpi's dict iteration order.
observation = pi0.observation_spec.reshape(sizes={"image": len(observation_openpi0.images)}).empty()
# Convert to JAX arrays so the functional `.at[...]` updates below are available.
observation["images"] = jax.numpy.asarray(observation["images"])
observation["images.mask"] = jax.numpy.asarray(observation["images.mask"])
for i, (image_key, image) in enumerate(observation_openpi0.images.items()):
    # Camera i goes into axis 1 (axis 0 presumably batch — TODO confirm);
    # mode="fill" drops out-of-bounds updates instead of clamping them.
    observation["images"] = observation["images"].at[:, i].set(image, mode="fill")
    observation["images.mask"] = observation["images.mask"].at[:, i].set(
        observation_openpi0.image_masks[image_key],
        mode="fill",
    )
# Text tokens and proprioceptive state are copied over unchanged.
observation["tokenized_text"] = observation_openpi0.tokenized_prompt
observation["tokenized_text.mask"] = observation_openpi0.tokenized_prompt_mask
observation["numeric_state"] = observation_openpi0.state


# robotodo action container holding openpi's fake action chunk.
action = pi0.action_spec.empty()
action["numeric_states"] = action_openpi0


def _assert_arrays_match(name, expected, actual, *, strict_shape=False):
    """Assert that `actual` (robotodo) matches `expected` (openpi) exactly.

    Args:
        name: label included in every failure message.
        expected: reference value from openpi, or None.
        actual: value from the robotodo port, or None.
        strict_shape: if True, require identical shapes instead of relying on
            the broadcasting semantics of `jax.numpy.array_equiv` (which would
            e.g. accept a (1, 1) loss against a (1, 2) reference if the values
            happened to coincide).

    Raises:
        AssertionError: if only one side is None, if shapes differ (when
            `strict_shape`), or if any element differs.
    """
    if expected is None or actual is None:
        # None only matches None; `==` against an array gives no labelled failure.
        assert expected is None and actual is None, (
            f"{name}: one side is None "
            f"({type(expected).__name__} vs {type(actual).__name__})"
        )
        return
    if strict_shape:
        assert jax.numpy.shape(expected) == jax.numpy.shape(actual), (
            f"{name}: shape {jax.numpy.shape(expected)} != {jax.numpy.shape(actual)}"
        )
    assert jax.numpy.array_equiv(expected, actual), name


# Prefix embeddings (image + text side).
openpi0_prefix = tuple(openpi0.embed_prefix(observation_openpi0))
pi0_prefix = tuple(pi0.embed_prefix(observation))
# zip() silently stops at the shorter input, so check the output counts first.
assert len(openpi0_prefix) == len(pi0_prefix), (
    f"embed_prefix: {len(openpi0_prefix)} outputs vs {len(pi0_prefix)}"
)
for i, (openpi0_embeddings, pi0_embeddedings) in enumerate(zip(openpi0_prefix, pi0_prefix)):
    _assert_arrays_match(f"embed_prefix[{i}]", openpi0_embeddings, pi0_embeddedings)


# Suffix embeddings (state/action side) at diffusion timestep 0 for every batch item.
batch_size = observation_openpi0.state.shape[0]
timestep = jax.numpy.broadcast_to(0., batch_size)

suffix_names = ["tokens", "input_mask", "ar_mask", "adarms_cond"]
openpi0_suffix = tuple(openpi0.embed_suffix(observation_openpi0, action_openpi0, timestep=timestep))
pi0_suffix = tuple(pi0.embed_suffix(observation, action, timestep=timestep))
assert len(openpi0_suffix) == len(pi0_suffix) == len(suffix_names), (
    f"embed_suffix: {len(openpi0_suffix)} outputs vs {len(pi0_suffix)} "
    f"(expected {len(suffix_names)})"
)
for name, openpi0_embeddings, pi0_embeddedings in zip(suffix_names, openpi0_suffix, pi0_suffix):
    _assert_arrays_match(name, openpi0_embeddings, pi0_embeddedings)

# TODO: the recorded outputs below show these differ; presumably this needs the
# weight copy (`should_copy_weights`) to work first.
computed_actions_openpi0 = openpi0.sample_actions(
    rng=make_rng(),
    observation=observation_openpi0,
    num_steps=denoising_num_steps,
)
computed_action = pi0.sample_actions(
    rng=make_rng(),
    batch_observation=observation,
    num_timesteps=action_horizon,
    denoising_num_steps=denoising_num_steps,
)
_assert_arrays_match(
    "sample_actions",
    computed_actions_openpi0,
    computed_action["numeric_states"],
    strict_shape=True,
)


computed_loss_openpi0 = openpi0.compute_loss(
    rng=make_rng(),
    observation=observation_openpi0,
    actions=action_openpi0,
)
computed_loss = pi0.compute_loss(
    rng=make_rng(),
    batch_observation=observation,
    batch_action=action,
)
_assert_arrays_match("compute_loss", computed_loss_openpi0, computed_loss, strict_shape=True)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: JAX_PLATFORMS=cpu


In [None]:
import pytest
import numpy
import openpi.models.tokenizer
from robotodo.algos.pi0.nn.tokenizers import Pi0Tokenizer, Pi05Tokenizer


@pytest.mark.parametrize(
    "prompt",
    [
        "do something",
        "1234_H_O_T_T_O_G_O",
    ],
)
def test_pi0tokenizer_openpi_compliance(prompt):
    """`Pi0Tokenizer` must produce the same token ids as openpi's `PaligemmaTokenizer`.

    openpi returns (tokens, mask). Only the entries where the mask is set are
    compared, presumably the non-padding tokens.
    """
    openpi_tokenizer = openpi.models.tokenizer.PaligemmaTokenizer()
    tokens_openpi, masks_openpi = openpi_tokenizer.tokenize(prompt)
    tokens_openpi = tokens_openpi[masks_openpi]

    pi0_tokenizer = Pi0Tokenizer()
    [tokens] = pi0_tokenizer([prompt])

    # Unlike `numpy.array_equiv`, this rejects shape mismatches instead of
    # broadcasting, and reports the differing elements on failure.
    numpy.testing.assert_array_equal(tokens_openpi, tokens)


@pytest.mark.parametrize(
    "prompt",
    [
        "do something",
        "1234 H_O_T_T_O_G_O",
    ],
)
@pytest.mark.parametrize(
    "state",
    [
        [1, 2, 3, 4],
    ],
)
def test_pi05tokenizer_openpi_compliance(prompt, state):
    """`Pi05Tokenizer` must match openpi's `PaligemmaTokenizer` on prompt + state.

    openpi returns (tokens, mask). Only the entries where the mask is set are
    compared, presumably the non-padding tokens.
    """
    openpi_tokenizer = openpi.models.tokenizer.PaligemmaTokenizer()
    tokens_openpi, masks_openpi = openpi_tokenizer.tokenize(prompt, state)
    tokens_openpi = tokens_openpi[masks_openpi]

    pi05_tokenizer = Pi05Tokenizer()
    # A batch of one prompt paired with a batch of one state. This is the
    # `pi05_tokenizer([prompt], [state])` form the interactive cells verify
    # matches openpi; the former `[[state]]` added an extra nesting level.
    [tokens] = pi05_tokenizer([prompt], [state])

    # Strict shape + element comparison with a readable diff on failure.
    numpy.testing.assert_array_equal(tokens_openpi, tokens)


In [28]:
# Single example input for interactively comparing the two Pi05 tokenizers.
prompt = "do something"
state = [1, 2, 3, 4]

In [29]:
import openpi.models.tokenizer

openpi_tokenizer = openpi.models.tokenizer.PaligemmaTokenizer()
# openpi tokenizes prompt + state together and returns (tokens, mask);
# keep only the entries where the mask is set.
tokens_openpi, masks_openpi = openpi_tokenizer.tokenize(prompt, state)
tokens_openpi = tokens_openpi[masks_openpi]
tokens_openpi.shape

(29,)

In [31]:
# Show openpi's mask-filtered token ids.
tokens_openpi

array([     2,   7071, 235292,    749,   2775, 235269,   3040, 235292,
       235248, 235284, 235308, 235308, 235248, 235284, 235308, 235308,
       235248, 235284, 235308, 235308, 235248, 235284, 235308, 235308,
       235289,    108,   4022, 235292, 235248])

In [30]:
from robotodo.algos.pi0.nn.tokenizers import Pi05Tokenizer

pi05_tokenizer = Pi05Tokenizer()

import numpy
# Batch of one prompt with one state; show the shape of the first (only) result.
numpy.asarray(pi05_tokenizer([prompt], [state]))[0].shape


(29,)

In [None]:
# Show the robotodo token ids for comparison with `tokens_openpi`.
numpy.asarray(pi05_tokenizer([prompt], [state]))[0]

array([     2,   7071, 235292,    749,   2775, 235269,   3040, 235292,
       235248, 235284, 235308, 235308, 235248, 235284, 235308, 235308,
       235248, 235284, 235308, 235308, 235248, 235284, 235308, 235308,
       235289,    108,   4022, 235292, 235248])

: 

In [27]:
# Exact elementwise comparison of the two tokenizations (raises on mismatch).
numpy.testing.assert_array_equal(tokens_openpi, pi05_tokenizer([prompt], [state])[0])

In [2]:
# Inspect both losses side by side (openpi first).
(computed_loss_openpi0, computed_loss)

(Array([[6.5782957, 3.4572563]], dtype=float32),
 Array([[13.951823,  9.148355]], dtype=float32))

In [39]:
# Inspect a freshly created seed-0 Rngs (its stream counter starts at 0).
flax.nnx.Rngs(jax.random.key(0))

Rngs(
  default=RngStream(
    key=RngKey(
      value=Array((), dtype=key<fry>) overlaying:
      [0 0],
      tag='default'
    ),
    count=RngCount(
      value=Array(0, dtype=uint32),
      tag='default'
    )
  )
)

In [22]:
# Inspect sampled actions from both implementations (openpi first).
(computed_actions_openpi0, computed_action["numeric_states"])

(Array([[[ 2.4330268 ,  2.4784737 , -1.1782779 , -0.27594906],
         [ 0.3605225 ,  0.1505478 , -0.54230285, -0.14090279]]],      dtype=float32),
 Array([[[ 0.7458704 ,  1.2616075 , -0.67275417,  0.9028505 ],
         [-0.734951  , -1.5478942 , -0.8076718 ,  0.847718  ]]],      dtype=float32))

In [3]:
# Inspect sampled actions from both implementations (openpi first).
(computed_actions_openpi0, computed_action["numeric_states"])


(Array([[[ 2.4330268 ,  2.4784737 , -1.1782779 , -0.27594906],
         [ 0.3605225 ,  0.1505478 , -0.54230285, -0.14090279]]],      dtype=float32),
 Array([[[ 1.0266508 ,  1.6777757 , -0.8130156 ,  0.54437053],
         [-0.13069725, -2.2352436 , -0.9091571 ,  1.016725  ]]],      dtype=float32))

In [23]:
# Inspect both losses side by side (openpi first).
(computed_loss_openpi0, computed_loss)

(Array([[6.5782957, 3.4572563]], dtype=float32),
 Array([[10.573036 ,  7.7422376]], dtype=float32))

In [17]:
# Shape of openpi's fake action batch.
openpi0_config.fake_act().shape

(1, 2, 4)

In [2]:
# Inspect both losses side by side (openpi first).
(computed_loss_openpi0, computed_loss)

(Array([[6.5782957, 3.4572563]], dtype=float32),
 Array([[7.173269]], dtype=float32))

In [18]:
# Horizon derived from the fake action batch.
action_horizon

2

In [44]:
# Shape of the openpi fake actions fed to both models.
action_openpi0.shape

(1, 2, 4)

In [4]:
# Inspect sampled actions from both implementations (openpi first).
(computed_actions_openpi0, computed_action["numeric_states"])

(Array([[[ 2.4330268 ,  2.4784737 , -1.1782779 , -0.27594906],
         [ 0.3605225 ,  0.1505478 , -0.54230285, -0.14090279]]],      dtype=float32),
 Array([[[ 0.7458704 ,  1.2616075 , -0.67275417,  0.9028505 ],
         [-0.734951  , -1.5478942 , -0.8076718 ,  0.847718  ]]],      dtype=float32))

In [36]:
# Inspect both losses side by side (openpi first).
computed_loss_openpi0, computed_loss

(Array([[6.5782957, 3.4572563]], dtype=float32),
 Array([[7.173269]], dtype=float32))

In [32]:
# Shape of the robotodo image tensor after filling the camera slots.
observation["images"].shape

(1, 3, 224, 224, 3)

In [33]:
# openpi's fake images, keyed by camera name.
observation_openpi0.images

{'base_0_rgb': Array([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         ...,
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          ...,
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]]

In [21]:
# Shape of the robotodo image tensor after filling the camera slots.
observation["images"].shape

(1, 1, 224, 224, 3)

In [24]:
# Scratch re-run of the embed_prefix parity check.
openpi0_prefix = tuple(openpi0.embed_prefix(observation_openpi0))
pi0_prefix = tuple(pi0.embed_prefix(observation))
# zip() silently stops at the shorter input, so first make sure nothing is skipped.
assert len(openpi0_prefix) == len(pi0_prefix), (len(openpi0_prefix), len(pi0_prefix))
for openpi0_embeddings, pi0_embeddedings in zip(openpi0_prefix, pi0_prefix):
    assert jax.numpy.array_equiv(openpi0_embeddings, pi0_embeddedings)

In [16]:
# Scratch re-run of the embed_suffix parity check at timestep 0 for every batch item.
batch_size = observation_openpi0.state.shape[0]

timestep = jax.numpy.broadcast_to(0., batch_size)

suffix_names = ["tokens", "input_mask", "ar_mask", "adarms_cond"]
openpi0_suffix = tuple(openpi0.embed_suffix(observation_openpi0, action_openpi0, timestep=timestep))
pi0_suffix = tuple(pi0.embed_suffix(observation, action, timestep=timestep))
# zip() would silently drop trailing outputs if the counts differed.
assert len(openpi0_suffix) == len(pi0_suffix) == len(suffix_names), (
    len(openpi0_suffix), len(pi0_suffix), len(suffix_names)
)
for name, openpi0_embeddings, pi0_embeddedings in zip(suffix_names, openpi0_suffix, pi0_suffix):
    if openpi0_embeddings is None or pi0_embeddedings is None:
        # None must match None; `==` against an array gives no labelled failure.
        assert openpi0_embeddings is None and pi0_embeddedings is None, name
    else:
        assert jax.numpy.array_equiv(openpi0_embeddings, pi0_embeddedings), name

In [13]:
# Debug: `==` on the last loop values (both None here, so True).
openpi0_embeddings == pi0_embeddedings

True

In [11]:
# Debug: with both values None this returns False (see output), which is why the
# suffix check special-cases None.
jax.numpy.array_equiv(openpi0_embeddings, pi0_embeddedings)

  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)


Array(False, dtype=bool)

In [10]:
# Debug: values left over from the last comparison loop iteration.
(openpi0_embeddings, pi0_embeddedings)

(None, None)

In [4]:
import flax.nnx
# Render both module trees for a structural comparison.
flax.nnx.display(pi0)
flax.nnx.display(openpi0)

In [29]:
# Repr of the openpi reference model.
openpi0

Pi0(action_dim=4, action_horizon=2, max_token_len=8)

In [28]:
# Repr of the robotodo model (module tree with parameter shapes).
pi0

Pi0(
  PaliGemma=Dict(
    llm=ToNNX(
      module=Module(
          # attributes
          configs = [Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16, lora_configs={}), Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16, lora_configs={})]
          embed_dtype = 'bfloat16'
          dropout = 0.0
          dropout_bdims = ()
          adarms = False
      ),
      rngs=None,
      linen_attributes=('final_norm', 'final_norm_1', 'layers', 'embedder'),
      embedder={'input_embedding': Param(
        value=Array(shape=(257152, 64), dtype=float32)
      )},
      final_norm={'scale': Param(
        value=Array(shape=(64,), dtype=float32)
      )},
      final_norm_1={'scale': Param(
        value=Array(shape=(64,), dtype=float32)
      )},
      layers={'attn': {'attn_vec_einsum': {'w': Param(
        value=Array(shape=(4, 8, 16, 64), dtype=float32)
      )}, 'attn_vec_einsum_1': {'w': Param(
        value=Array(shape=(4, 8, 1

In [26]:
# Debug: values left over from the last comparison loop iteration.
openpi0_embeddings, pi0_embeddedings

(Array([[[ 1.79053843e-01,  4.74933982e-02,  1.45280170e+00,
          -1.13485014e+00, -6.58169508e-01,  9.03256595e-01,
           1.00755048e+00,  1.01126623e+00, -4.22463298e-01,
           6.74940407e-01,  4.67697948e-01,  5.17602980e-01,
          -1.05136085e+00,  1.71915829e-01, -1.15855324e+00,
           2.77939975e-01,  5.37473738e-01, -1.62840128e-01,
           1.08238804e+00,  1.06681776e+00, -5.32237113e-01,
           5.46865463e-01, -1.72765642e-01,  7.39489377e-01,
           4.88460690e-01,  1.12517476e+00,  6.18264079e-01,
           9.53357816e-01, -8.01788568e-01, -2.13326764e+00,
          -2.14591002e+00, -1.48850334e+00, -1.54589212e+00,
          -1.34846640e+00,  2.03086436e-02,  8.89516652e-01,
          -5.90233445e-01,  3.14412490e-02,  2.23233342e+00,
           2.28826076e-01, -2.03966570e+00, -6.59065366e-01,
           4.03764784e-01,  8.55188847e-01,  2.00244403e+00,
           9.86532092e-01,  5.79643488e-01, -1.33064306e+00,
           7.15444803e-0

In [15]:
# Shape of the robotodo action container's numeric states.
action["numeric_states"].shape

(1, 1, 4)

In [12]:
# Action horizon configured on the openpi model.
openpi0.action_horizon

2

In [11]:
# Debug: which suffix output the comparison loop stopped on.
name


'tokens'

(Array([[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [-0.146484, -0.110352, 0.130859, ..., -0.0805664, 0.150391,
          -0.0383301],
         [-0.146484, -0.110352, 0.130859, ..., -0.0805664, 0.150391,
          -0.0383301],
         [-0.146484, -0.110352, 0.130859, ..., -0.0805664, 0.150391,
          -0.0383301]]], dtype=bfloat16),
 Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,
          Tru

In [15]:
# Inspect the robotodo observation.
observation

{'images': Array([[[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          ...,
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [

In [14]:
# Inspect both losses side by side (openpi first).
(computed_loss_openpi0, computed_loss)

(Array([[6.5782957, 3.4572563]], dtype=float32),
 Array([[7.1274056]], dtype=float32))

In [8]:
# Inspect the robotodo observation.
observation

{'images': Array([[[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          ...,
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],
 
          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.],
           ...,
           [1., 1., 1.],
           [

In [11]:
# openpi's sampled actions.
computed_actions_openpi0


Array([[[ 2.4330268 ,  2.4784737 , -1.1782779 , -0.27594906],
        [ 0.3605225 ,  0.1505478 , -0.54230285, -0.14090279]]],      dtype=float32)

In [12]:
# robotodo's sampled actions.
computed_action["numeric_states"]

Array([[[ 0.7606724 ,  1.4389639 , -0.7811883 ,  0.8792597 ],
        [-0.7339352 , -1.4569051 , -0.87182415,  0.8626354 ]]],      dtype=float32)

In [3]:
# Inspect the robotodo observation (this output shows uninitialized values from `.empty()`).
observation

{'images': array([[[[[6.8697638e+21, 4.5631883e-41, 6.8697638e+21],
           [4.5631883e-41, 7.5960469e+03, 0.0000000e+00],
           [7.5960469e+03, 0.0000000e+00, 3.0734703e+32],
           ...,
           [1.0950298e+27, 2.1408451e-10, 7.3604004e+22],
           [4.7399521e+30, 1.6041009e-10, 2.8952667e+32],
           [1.7699258e+31, 1.3493954e+37, 1.3888643e-16]],
 
          [[1.6677464e+19, 1.6594346e-07, 1.8917655e+23],
           [9.8717143e+17, 4.4652594e+30, 7.2708401e+31],
           [2.4512243e-09, 3.0981867e+27, 1.6635618e+22],
           ...,
           [1.1040178e-05, 3.3805180e-18, 3.0353856e+32],
           [1.6529180e+19, 7.0366745e+22, 1.5136596e+37],
           [1.3888643e-16, 2.7944582e+20, 4.0329896e-11]],
 
          [[1.5434927e-19, 5.7072828e-08, 8.1225226e+17],
           [2.1445974e-19, 2.7944582e+20, 7.2133025e+22],
           [3.9336976e-14, 1.3888643e-16, 4.3059060e+21],
           ...,
           [2.5377398e-09, 2.5184978e-12, 7.6797164e+28],
        

In [1]:
# %env CUDA_VISIBLE_DEVICES=1
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# %env TF_GPU_ALLOCATOR=cuda_malloc_async
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=.99
%env JAX_PLATFORMS=cpu

from robotodo.algos.pi0.nn.pi0 import Pi0
import jax
import flax.nnx

# Optional sharding setup (currently disabled):
# mesh = jax.sharding.Mesh(jax.devices(), ("x",))
# sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# Standalone robotodo Pi0 built from `flax.nnx.Rngs(0)` with no explicit config
# (presumably the default config — TODO confirm).
pi0 = Pi0(flax.nnx.Rngs(0))

key = jax.random.key(0)
# Random inputs drawn from the model's own observation / action specs.
observation = pi0.observation_spec.random()
action = pi0.action_spec.random()
# NOTE(review): `action` and `timestep` are not used by the cells visible below.
timestep = jax.numpy.asarray([0])


# JIT-compile sample_actions with the step counts marked static, so they must be
# hashable Python values rather than traced arrays.
sample_actions = pi0.compile_jit(
    pi0.sample_actions,
    static_argnames=("num_timesteps", "denoising_num_steps"),
)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: JAX_PLATFORMS=cpu


In [5]:
# Inspect the randomly drawn observation.
observation

{'images': array([[[[[ 3.11678737e-01,  8.80481899e-01, -2.67851025e-01],
           [-9.40250039e-01, -7.99365938e-01,  9.49142396e-01],
           [ 5.68615198e-01, -8.51742685e-01, -3.94750446e-01],
           ...,
           [-3.53965968e-01,  9.48632300e-01, -6.49480462e-01],
           [-8.12962174e-01,  4.13493104e-02, -2.79475242e-01],
           [-9.05155540e-01,  9.15395677e-01, -8.14229250e-01]],
 
          [[-7.48462856e-01, -6.51410997e-01,  3.76508534e-01],
           [-6.12657130e-01, -4.82654393e-01, -9.18885469e-01],
           [ 7.86623478e-01, -5.81557453e-01, -1.04843117e-01],
           ...,
           [ 7.53446281e-01,  3.23209554e-01,  1.61390826e-02],
           [-5.44424176e-01, -7.76082754e-01, -9.01604116e-01],
           [ 7.18752682e-01, -4.51095849e-01, -2.52806008e-01]],
 
          [[-6.49151653e-02, -1.31368097e-02, -7.62562156e-01],
           [-1.43344685e-01, -9.11389709e-01, -6.08043551e-01],
           [-4.58992273e-01, -2.90953666e-01,  2.9010081

In [4]:
# Run the jitted sampler with only rng and observation, leaving num_timesteps /
# denoising_num_steps at their defaults (presumably — TODO confirm).
sample_actions(key, observation)

{'numeric_states': Array([[[ 1.209037  ,  1.3700281 , -1.206918  ,  0.4093421 ,
          -1.5287708 , -0.36389875,  0.93598616,  0.2658462 ,
           0.3587545 , -1.2649344 ,  1.3237348 , -1.8578876 ,
           0.21691251, -0.39009964,  2.2527442 ,  1.9898949 ,
           0.21479926,  0.06678981,  0.7127061 , -1.8954036 ,
          -1.2006513 ,  2.333873  ,  1.8122951 ,  1.5808662 ,
           1.8202456 ,  0.76633155,  0.7927439 ,  0.48942977,
           1.0296096 , -0.5523396 , -2.9859571 ,  1.2749529 ]]],      dtype=float32)}