In [1]:
from ros_py_types.geometry_msgs import Pose
from bam_gym.envs import custom_spaces
from bam_gym.envs.clients import GenericGymClient
from bam_gym.transport import MockTransport

import gymnasium as gym
from gymnasium import spaces
import numpy as np

%reload_ext autoreload
%autoreload 2


In [2]:
# Build a client on top of the mock transport, passing no other arguments.
env = GenericGymClient(transport=MockTransport())

In [3]:
# Show the observation space of the client, then one sample drawn from it.
print(env.observation_space)
print(env.observation_space.sample())

Dict()
{}


In [4]:
# Rebuild the client with n_pose=2 and inspect the resulting observation space.
env = GenericGymClient(transport=MockTransport(), n_pose=2)
print(env.observation_space)

# List every top-level entry of the space on its own line.
for space_name, subspace in env.observation_space.items():
    print(space_name, subspace)

# Draw one observation, passing the mask exposed by the client.
sample = env.observation_space.sample(mask=env.obs_mask)

# Take the first sampled pose: print it raw, then converted through Pose.
first_pose = sample['pose'][0]
print(first_pose)

pose = Pose.from_dict(first_pose)
print(pose)
print(pose.to_array())


Dict('pose_names': Sequence(Text(1, 32, charset=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz), stack=False), 'pose': Sequence(Dict('orientation': Dict('x': Box(-3.1415927, 3.1415927, (), float32), 'y': Box(-3.1415927, 3.1415927, (), float32), 'z': Box(-3.1415927, 3.1415927, (), float32)), 'position': Dict('x': Box(-10.0, 10.0, (), float32), 'y': Box(-10.0, 10.0, (), float32), 'z': Box(-10.0, 10.0, (), float32))), stack=False))
pose_names Sequence(Text(1, 32, charset=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz), stack=False)
pose Sequence(Dict('orientation': Dict('x': Box(-3.1415927, 3.1415927, (), float32), 'y': Box(-3.1415927, 3.1415927, (), float32), 'z': Box(-3.1415927, 3.1415927, (), float32)), 'position': Dict('x': Box(-10.0, 10.0, (), float32), 'y': Box(-10.0, 10.0, (), float32), 'z': Box(-10.0, 10.0, (), float32))), stack=False)
{'orientation': {'x': array(-0.5273541, dtype=float32), 'y': array(0.3649438, dtype=float32), 'z': array(-2.7931607

In [5]:
# Open question: how can a policy validate the observation space it receives?
#
# One approach: draw a sample from the observation space and check its values.
#
# For example, for blind_policy:

sample: dict = env.observation_space.sample(mask=env.obs_mask)

# Fall back to an empty tuple when the sample carries no 'pose' entry.
pose_list = sample.get('pose', ())
n_poses = len(pose_list)
assert n_poses >= 1
print(n_poses)
print(pose_list)

2
({'orientation': {'x': array(-1.7038851, dtype=float32), 'y': array(2.8505397, dtype=float32), 'z': array(1.1059657, dtype=float32)}, 'position': {'x': array(-9.191089, dtype=float32), 'y': array(3.0121968, dtype=float32), 'z': array(-4.7050867, dtype=float32)}}, {'orientation': {'x': array(2.9969578, dtype=float32), 'y': array(-3.0666096, dtype=float32), 'z': array(2.9440293, dtype=float32)}, 'position': {'x': array(7.883803, dtype=float32), 'y': array(7.649813, dtype=float32), 'z': array(-7.037903, dtype=float32)}})


In [6]:
# Box with scalar bounds and shape=None; every sample printed below is a
# one-element float32 array.
action_space = spaces.Box(low=-1, high=1, shape=None, dtype=np.float32)

# Seed the space before drawing from it.
action_space.seed(1)

NUM_DRAWS = 10
for _draw in range(NUM_DRAWS):
    print(action_space.sample())


[0.02364325]
[0.90092736]
[-0.71168077]
[0.8972989]
[-0.37633708]
[-0.1533471]
[0.65540516]
[-0.18160173]
[0.09918737]
[-0.9448818]


In [7]:
# Sequence space whose elements come from a Box built with shape=None.
element_space = spaces.Box(low=-1, high=1, shape=None, dtype=np.float32)
action_space = spaces.Sequence(element_space)

# The Sequence mask is a (length, element-mask) pair: 2 elements, no element mask.
sequence_mask = (2, None)
action_space.sample(mask=sequence_mask)

(array([0.980985], dtype=float32), array([-0.9163906], dtype=float32))

In [8]:
# Same Sequence as before, but each element is now a Box of shape (5,).
element_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
action_space = spaces.Sequence(element_space)

# Request 2 elements and leave the elements themselves unmasked.
sequence_mask = (2, None)
action_space.sample(mask=sequence_mask)

(array([ 0.92100596,  0.06086158, -0.9437766 ,  0.0060397 ,  0.84798694],
       dtype=float32),
 array([ 0.85627806, -0.17273515,  0.8708439 , -0.30081898,  0.39203802],
       dtype=float32))

In [9]:
# Start from an empty Dict space and attach each entry by item assignment.
observation_dict = spaces.Dict({})

# Names: a sequence of text strings of at most 32 characters.
name_space = spaces.Text(max_length=32)
observation_dict["obs_names"] = spaces.Sequence(name_space)

# Values: a sequence of Box elements built with shape=None.
value_space = spaces.Box(low=-1, high=1, shape=None, dtype=np.float32)
observation_dict["obs"] = spaces.Sequence(value_space)

In [10]:
# Both entries are Sequence spaces, so each takes a (length, element-mask)
# pair; ask for 5 unmasked elements per entry.
five_unmasked = (5, None)
mask = {"obs_names": five_unmasked, "obs": five_unmasked}
observation_dict.sample(mask=mask)

{'obs_names': ('cWiQY5OCFzjUlp4lRXyGyTrovVhjXok',
  'fTY3nTIDGQr',
  'AHrJjLMbiLpbcbm',
  'z2xEBuMYpl1WZAsGCY',
  'kmIA4N6YelCqjfj'),
 'obs': (array([-0.84815335], dtype=float32),
  array([0.4123818], dtype=float32),
  array([0.4232521], dtype=float32),
  array([0.26903862], dtype=float32),
  array([0.87645984], dtype=float32))}

In [11]:
# Rebuild the Dict space; this time each "obs" element is a 2x2 Box.
observation_dict = spaces.Dict({})

name_space = spaces.Text(max_length=32)
observation_dict["obs_names"] = spaces.Sequence(name_space)

matrix_space = spaces.Box(low=-1, high=1, shape=(2,2), dtype=np.float32)
observation_dict["obs"] = spaces.Sequence(matrix_space)

In [12]:
# One (length, element-mask) pair per Dict entry: 5 unmasked elements each.
mask = {entry: (5, None) for entry in ("obs_names", "obs")}
observation_dict.sample(mask=mask)

{'obs_names': ('9QKeweedClxZdtYpw5QAMwfuou',
  'hfsbpKdvIO2gQ70M',
  'nRtSQx3FlFGyLlaoRIDfv5orev0',
  'KpOlqYkQuluV0vZoRfjIFg',
  'v2EDSqbcV6oWJe'),
 'obs': (array([[ 0.9213963 , -0.65090233],
         [-0.07813599,  0.9216735 ]], dtype=float32),
  array([[-0.00459971,  0.88070905],
         [-0.7253376 ,  0.93278176]], dtype=float32),
  array([[ 0.73217714, -0.1590461 ],
         [-0.64967006, -0.26674056]], dtype=float32),
  array([[ 0.719505  ,  0.99923456],
         [-0.45023882,  0.40345946]], dtype=float32),
  array([[-0.24893403,  0.15110302],
         [ 0.89823145, -0.24536996]], dtype=float32))}

What about nested sequences?
https://gymnasium.farama.org/api/spaces/composite/#gymnasium.spaces.Sequence

In [13]:
# Inner Dict space: a sequence of names plus a sequence of Box values.
inner_dict = spaces.Dict({})

name_space = spaces.Text(max_length=32)
inner_dict["obs_names"] = spaces.Sequence(name_space)

value_space = spaces.Box(low=-1, high=1, shape=None, dtype=np.float32)
inner_dict["obs"] = spaces.Sequence(value_space)

# Nest it: the final observation space is a Sequence of those Dict spaces.
observation_dict = spaces.Sequence(inner_dict)

In [14]:
# Mask for each Dict inside the outer Sequence: 5 unmasked elements per entry.
per_item_mask = {"obs_names": (5, None), "obs": (5, None)}

# Outer Sequence mask: 5 items, each sampled with the Dict mask above.
mask = (5, per_item_mask)
observation_dict.sample(mask=mask)

({'obs_names': ('xDA9RG1sveeumpHpf0LAKzKfak',
   'RpBs2GJkcydiepOiwdCHr',
   'IoBEzOlOVWeU',
   'jiqpP2s1IAm8fsPPNdFI',
   'lyY1iuY'),
  'obs': (array([-0.3949288], dtype=float32),
   array([0.1260257], dtype=float32),
   array([-0.05683532], dtype=float32),
   array([0.413885], dtype=float32),
   array([0.1613597], dtype=float32))},
 {'obs_names': ('QoEuDBBlXTZrQjTXrYQ0eFoy66',
   'uM4WG7',
   '9Ty7O82',
   '8MJcS3daJ3GuAvk4ZS6E1g',
   'a9S0MezC87eyrzVdFBWIl'),
  'obs': (array([0.22787406], dtype=float32),
   array([-0.3224399], dtype=float32),
   array([0.34594372], dtype=float32),
   array([0.37780488], dtype=float32),
   array([0.3029198], dtype=float32))},
 {'obs_names': ('qo41dc4FWW',
   'YZrlbTrXumc2',
   'OX',
   'Haqab',
   'iBuw0Kw0hUEDxN4ySx9XBAZ'),
  'obs': (array([0.12652044], dtype=float32),
   array([0.52241015], dtype=float32),
   array([-0.73501694], dtype=float32),
   array([-0.7111185], dtype=float32),
   array([0.00390197], dtype=float32))},
 {'obs_names': ('b1uiHwP

In [39]:
# Client configured with num_envs=5, n_pose=3 and automask enabled.
env = GenericGymClient(transport=MockTransport(), num_envs=5, n_pose=3, automask=True)
print(env.observation_space)

# No mask is passed to sample() here.
sample = env.observation_space.sample()
print(sample)
print(len(sample))

# With more than one env, take the first element of the sample before
# reading 'pose'; otherwise read 'pose' from the sample directly.
first_obs = sample[0] if env.num_envs > 1 else sample
print(len(first_obs['pose']))

<bam_gym.envs.custom_spaces.MaskedSpaceWrapper object at 0x7362a4fc3380>
({'pose_names': ('72zJ3kF', 'hcojNZgAteh90VWnZTYhzAGMf', 'yLRpN'), 'pose': ({'orientation': {'x': array(0.309465, dtype=float32), 'y': array(-0.60658914, dtype=float32), 'z': array(-1.7445222, dtype=float32)}, 'position': {'x': array(-7.8223505, dtype=float32), 'y': array(-6.119114, dtype=float32), 'z': array(-9.900667, dtype=float32)}}, {'orientation': {'x': array(1.7096548, dtype=float32), 'y': array(1.2376128, dtype=float32), 'z': array(2.31895, dtype=float32)}, 'position': {'x': array(6.707916, dtype=float32), 'y': array(-0.33873174, dtype=float32), 'z': array(-9.598553, dtype=float32)}}, {'orientation': {'x': array(-2.9833686, dtype=float32), 'y': array(-0.820378, dtype=float32), 'z': array(2.861606, dtype=float32)}, 'position': {'x': array(-1.4821353, dtype=float32), 'y': array(-5.5466394, dtype=float32), 'z': array(5.3648176, dtype=float32)}})}, {'pose_names': ('hbB06XdKT5nsGT8NvMCdSVa0Gf3dAu', 'i0NVGGjS4Vw