In [1]:
# GPU / allocator configuration. It lives in the first cell so that it is applied
# before the cell that imports torch runs.
# Comments are kept on their own lines: text after a `%env VAR=value` would become
# part of the value.
# Disabled alternative (presumably: cap XLA's GPU memory share instead of disabling
# preallocation — confirm against the JAX docs before enabling):
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=0.4
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Plain-Python equivalent of the three magics above, kept for reference:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

env: CUDA_VISIBLE_DEVICES=0
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import RandomSampler, ReplayBuffer
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv, set_exploration_type, ExplorationType
from torchrl.modules import (
	AdditiveGaussianModule,
	MultiAgentMLP,
	ProbabilisticActor,
	TanhDelta,
)
from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators
from tqdm.auto import tqdm

torch.manual_seed(911)

<torch._C.Generator at 0x7fbf9799e510>

In [None]:
# NOTE(review): `model_name` is not read by any other cell visible in this notebook
# (the wandb run name in the training cell is a separate hard-coded string) — confirm
# whether it is still needed.
model_name = "ISAC"

In [3]:
# --- Hardware ---
device = torch.device("cuda", 0)

# --- Data collection ---
n_iters = 100  # number of collect-then-train iterations
frames_per_batch = 27_000  # frames handed over by the collector per iteration
total_frames = n_iters * frames_per_batch

# --- Replay buffer ---
memory_size = 10**6  # capacity passed to the replay-buffer storage

# --- Optimisation ---
n_optimiser_steps = 200  # gradient steps per collected batch
train_batch_size = 500  # transitions sampled from the buffer per gradient step
lr = 3e-4  # learning rate of both Adam optimisers
max_grad_norm = 1.0  # gradient-norm clipping threshold

# --- DDPG ---
gamma = 0.99  # discount factor given to the TD0 value estimator
polyak_tau = 0.005  # `tau` of the soft target-network update

In [4]:
from rl_env.core.environment import Env
from rl_env.torchrl.torchrl_wrapper import MyEnvWrapper

In [5]:
# Layout parameters; they are reused for the visualisation environment and logged to wandb.
width, height = 10, 10
obstacle_density = 0.0
num_agents = 5
grain_factor = 6

# Raw (unwrapped) environment. The remaining simulation settings are passed as literals;
# the same literals are repeated where the visualisation environment is built.
env = Env(
	width=width,
	height=height,
	obstacle_density=obstacle_density,
	num_agents=num_agents,
	grain_factor=grain_factor,
	contact_force=500,
	contact_margin=1e-3,
	dt=0.01,
	max_steps=900,
	frameskip=7,
	max_obs=8,
)

In [6]:
# Number of parallel environments such that `env.max_steps` steps of every environment
# add up to one collector batch (floor division: a remainder would be dropped).
num_envs = frames_per_batch // env.max_steps
print("num_envs =", num_envs)

# NOTE: `env` is rebound here from the raw Env to its batched wrapper, so this cell is
# not idempotent — re-run the cell that builds the raw Env before re-running this one.
env = MyEnvWrapper(env, device=device, batch_size=[num_envs])
env.set_seed(0)

num_envs = 30


795726461

In [7]:
# Print every spec of the wrapped environment as a labelled block.
for spec_label, spec in (
	("action_spec:\n", env.full_action_spec),
	("reward_spec:\n", env.full_reward_spec),
	("done_spec:\n", env.full_done_spec),
	("observation_spec:\n", env.observation_spec),
):
	print(spec_label, spec, "\n")

action_spec:
 Composite(
    agents: Composite(
        action: BoundedContinuous(
            shape=torch.Size([30, 5, 2]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([30, 5, 2]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([30, 5, 2]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous),
        device=cuda:0,
        shape=torch.Size([30, 5])),
    device=cuda:0,
    shape=torch.Size([30])) 

reward_spec:
 Composite(
    agents: Composite(
        reward: BoundedContinuous(
            shape=torch.Size([30, 5, 1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([30, 5, 1]), device=cuda:0, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([30, 5, 1]), device=cuda:0, dtype=torch.float32, contiguous=True)),
            device=cuda:0,
         

In [8]:
# Print the tensordict keys under which the environment exposes actions, rewards and done flags.
for keys_label, keys in (
	("action_keys:", env.action_keys),
	("reward_keys:", env.reward_keys),
	("done_keys:", env.done_keys),
):
	print(keys_label, keys)

action_keys: [('agents', 'action')]
reward_keys: [('agents', 'reward')]
done_keys: ['done', 'terminated']


In [9]:
# Add a RewardSum transform that reads the environment's reward key and writes its
# result to ("agents", "episode_reward"); the training loop reports that entry.
episode_return_tracker = RewardSum(
	in_keys=[env.reward_key],
	out_keys=[("agents", "episode_reward")],
)
env = TransformedEnv(env, episode_return_tracker)

In [10]:
# Network sizes are read off the environment specs (last dimension = per-agent feature size).
obs_dim = env.observation_spec["agents", "observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# Actor network: one MLP applied per agent.
policy_net = MultiAgentMLP(
	n_agent_inputs=obs_dim,
	n_agent_outputs=action_dim,
	n_agents=num_agents,
	centralized=False,
	share_params=True,  # Can be changed
	depth=2,
	num_cells=256,
	activation_class=torch.nn.Tanh,
	device=device,
)



In [11]:
# Tensordict wrapper around the actor network: reads ("agents", "observation") and
# writes the network output to ("agents", "param").
policy_module = TensorDictModule(
	module=policy_net,
	in_keys=[("agents", "observation")],
	out_keys=[("agents", "param")],
)

In [12]:
# Bounds of the unbatched per-agent action spec, handed to the TanhDelta distribution.
action_bounds = env.full_action_spec_unbatched["agents", "action"].space

# Actor: turns ("agents", "param") into an action written under the environment's action key.
policy = ProbabilisticActor(
	module=policy_module,
	spec=env.action_spec,
	in_keys=[("agents", "param")],
	out_keys=[env.action_key],
	distribution_class=TanhDelta,
	distribution_kwargs={"low": action_bounds.low, "high": action_bounds.high},
	return_log_prob=False,
)

In [13]:
# Gaussian action noise whose sigma is annealed from 0.9 to 0.1 over the first half of
# the training frames (the training loop advances it with `.step(current_frames)`).
exploration_noise = AdditiveGaussianModule(
	spec=policy.spec,
	action_key=("agents", "action"),
	sigma_init=0.9,
	sigma_end=0.1,
	annealing_num_steps=total_frames // 2,
)

# Policy used for data collection: the actor followed by the noise module.
exploration_policy = TensorDictSequential(policy, exploration_noise)

In [14]:
def concat_obs_action(obs, action):
	"""Join observation and action along the last (feature) dimension."""
	return torch.cat([obs, action], dim=-1)


# Step 1 of the critic: build its input, ("agents", "obs_action").
cat_module = TensorDictModule(
	concat_obs_action,
	in_keys=[("agents", "observation"), ("agents", "action")],
	out_keys=[("agents", "obs_action")],
)

# Input width of the critic = observation features + action features.
critic_input_dim = (
	env.observation_spec["agents", "observation"].shape[-1]
	+ env.full_action_spec["agents", "action"].shape[-1]
)

critic_net = MultiAgentMLP(
	n_agent_inputs=critic_input_dim,
	n_agent_outputs=1,
	n_agents=num_agents,
	centralized=True,  # True for maddpg and false for iddpg
	share_params=True,  # can be changed
	depth=2,
	num_cells=256,
	activation_class=torch.nn.Tanh,
	device=device,
)

# Step 2 of the critic: one value per agent, written to ("agents", "state_action_value").
critic_module = TensorDictModule(
	module=critic_net,
	in_keys=[("agents", "obs_action")],
	out_keys=[("agents", "state_action_value")],
)

critic = TensorDictSequential(cat_module, critic_module)

In [15]:
# Collector that steps `env` with the noisy policy. Iterating over it yields batches of
# `frames_per_batch` frames until `total_frames` have been collected.
collector = SyncDataCollector(
	env,
	policy=exploration_policy,
	frames_per_batch=frames_per_batch,
	total_frames=total_frames,
	device=device,
	storing_device=device,
)

In [16]:
# Storage for up to `memory_size` entries, kept on the training device.
transition_storage = LazyTensorStorage(memory_size, device=device)

# Buffer with a random sampler; `sample()` returns `train_batch_size` entries by default.
replay_buffer = ReplayBuffer(
	storage=transition_storage,
	sampler=RandomSampler(),
	batch_size=train_batch_size,
)

In [17]:
# DDPG objective over the actor and the critic, with an L2 value loss.
loss_module = DDPGLoss(
	actor_network=policy,
	value_network=critic,
	loss_function="l2",
	delay_value=True,
)

# Point the loss at the per-agent entries used in this notebook. The ("agents", "done") and
# ("agents", "terminated") entries are the ones `process_batch` adds when they are missing.
loss_module.set_keys(
	state_action_value=("agents", "state_action_value"),
	reward=env.reward_key,
	done=("agents", "done"),
	terminated=("agents", "terminated"),
)
loss_module.make_value_estimator(ValueEstimators.TD0, gamma=gamma)

target_updater = SoftUpdate(loss_module, tau=polyak_tau)

# One Adam optimiser per loss term, each over the matching parameter set of the loss module.
trainable_params = {
	"loss_actor": loss_module.actor_network_params,
	"loss_value": loss_module.value_network_params,
}
optimisers = {
	loss_name: torch.optim.Adam(params.flatten_keys().values(), lr=lr)
	for loss_name, params in trainable_params.items()
}

In [18]:
def process_batch(batch: TensorDictBase) -> TensorDictBase:
	"""
	Make sure the batch carries `("next", "agents", "done")` and `("next", "agents", "terminated")`.

	An entry that is missing is built from the root-level `("next", "done")` /
	`("next", "terminated")` entry by appending a dimension and expanding it to the shape of
	the `"agents"` entry plus a trailing dimension of size 1. Entries that already exist are
	left untouched. The batch is modified in place and returned.
	"""
	existing_keys = list(batch.keys(True, True))
	agents_shape = batch.get_item_shape("agents")

	for flag in ("done", "terminated"):
		per_agent_key = ("next", "agents", flag)
		if per_agent_key in existing_keys:
			continue
		shared_flag = batch.get(("next", flag))
		batch.set(per_agent_key, shared_flag.unsqueeze(-1).expand((*agents_shape, 1)))

	return batch

In [19]:
import wandb

In [None]:
pbar = tqdm(total=n_iters, desc="episode_reward_mean = 0")

wandb.init(
	project="myenv_xppo",
	config={
		"width": width,
		"height": height,
		"obstacle_density": obstacle_density,
		"num_agents": num_agents,
		"grain_factor": grain_factor,
		"frames_per_batch": frames_per_batch,
		"n_iters": n_iters,
		"total_frames": total_frames,
		"n_optimiser_steps": n_optimiser_steps,
		"train_batch_size": train_batch_size,
		"lr": lr,
		"max_grad_norm": max_grad_norm,
		"gamma": gamma,
		"polyak_tau": polyak_tau,
	},
	name="MADDPG1_collision_2.0_goal_dist_log_0.01",
)

try:
	# Training/collection iterations: every `batch` is one batch yielded by the collector
	for iteration, batch in enumerate(collector):
		current_frames = batch.numel()
		batch = process_batch(batch)  # Util to expand done keys if needed

		# This just affects the leading dimensions in batch_size of the tensordict
		data_view = batch.reshape(-1)
		replay_buffer.extend(data_view)

		for _ in range(n_optimiser_steps):
			subdata = replay_buffer.sample()
			loss_vals = loss_module(subdata)

			grad_norms = {}

			# Actor first, then critic; each loss is stepped by its own optimiser
			for loss_name in ["loss_actor", "loss_value"]:
				loss = loss_vals[loss_name]
				optimiser = optimisers[loss_name]

				loss.backward()

				# Clip only the parameters owned by this optimiser; the value returned by
				# clip_grad_norm_ is kept for logging
				params = optimiser.param_groups[0]["params"]
				grad_norms[loss_name] = torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

				optimiser.step()
				optimiser.zero_grad()

			wandb.log({
				"loss_actor": loss_vals["loss_actor"].item(),
				"loss_value": loss_vals["loss_value"].item(),
				"grad_norm_actor": grad_norms["loss_actor"].item(),
				"grad_norm_value": grad_norms["loss_value"].item(),
			})

			# Soft-update the target network
			target_updater.step()

		# Exploration sigma anneal update
		exploration_policy[-1].step(current_frames)

		# Episode reward is read only at transitions flagged as done. The mask is built by
		# dropping the last dimension of the three-dimensional ("next", "done") entry.
		done = batch.get(("next", "done"))[:, :, 0]
		iteration_log = {"iter": iteration}
		if done.any():
			episode_reward_mean = batch.get(("next", "agents", "episode_reward"))[done].mean().item()
			iteration_log["ep_rew_mean"] = episode_reward_mean
			pbar.set_description(f"episode_reward_mean = {episode_reward_mean :.2f}", refresh=False)
		# else: no episode ended in this batch; the mean of an empty selection would be NaN,
		# so the metric is skipped and the previous description is kept.

		wandb.log(iteration_log)
		pbar.update()
except BaseException:
	# Close the run as failed so an interrupted or crashed training does not stay open
	wandb.finish(exit_code=1)
	raise
else:
	wandb.finish()
finally:
	pbar.close()

episode_reward_mean = 0:   0%|          | 0/100 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mapshenitsyn[0m ([33mapshenitsyn-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))


0,1
ep_rew_mean,▇▁▃▄████████████████████████████████████
grad_norm_actor,█▃▂▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
grad_norm_value,▆▂▂▄▅█▅█▇▅▄▇▆▅▆█▇▅▅▇▄▅▂▁▂▂▂▂▂▂▂▂▁▃▂▁▃▂▃▂
iter,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
loss_actor,▅▅██▇▆▇▆▇▅▅▅▅▅▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_value,▆█▆██▄▄▄▄▄▄▄▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
ep_rew_mean,6.90189
grad_norm_actor,0.01787
grad_norm_value,0.12523
iter,99.0
loss_actor,-1.02232
loss_value,0.02058


In [21]:
import dataclasses

Good seeds for testing:

4 (1)

5 (0)

In [None]:
def get_state_from_envs(state, env_id):
	"""Return a new instance of `type(state)` in which every dataclass field is indexed by `env_id`."""
	selected = {}
	for field in dataclasses.fields(state):
		selected[field.name] = getattr(state, field.name)[env_id]
	state_cls = type(state)
	return state_cls(**selected)

def rendering_callback(env, td):
	"""Rollout callback: append the current state of environment 0 to `env.state_seq`; `td` is unused."""
	first_env_state = get_state_from_envs(env._state, 0)
	env.state_seq.append(first_env_state)


# Separate two-environment copy used only for the visualisation rollout. It is built from the
# same layout variables and the same literal simulation settings as the training environment.
viz_base_env = Env(
	width=width,
	height=height,
	obstacle_density=obstacle_density,
	num_agents=num_agents,
	grain_factor=grain_factor,
	contact_force=500,
	contact_margin=1e-3,
	dt=0.01,
	max_steps=env.max_steps,
	frameskip=7,
	max_obs=8,
)

viz_wrapped_env = MyEnvWrapper(viz_base_env, device=device, batch_size=[2])
viz_wrapped_env.set_seed(5)

viz_env = TransformedEnv(
	viz_wrapped_env,
	RewardSum(in_keys=[viz_wrapped_env.reward_key], out_keys=[("agents", "episode_reward")]),
)

# Filled by `rendering_callback` with one state per rollout step.
viz_env.state_seq = []

# Gradient-free rollout in deterministic exploration mode; the callback records the state
# of environment 0 at every step.
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
	out = viz_env.rollout(
		max_steps=viz_env.max_steps + 1,
		policy=exploration_policy,
		callback=rendering_callback,
		auto_reset=True,
		auto_cast_to_device=True,
		break_when_any_done=True,
	)

# Number of recorded states and the step counter of the last one.
len(viz_env.state_seq), viz_env.state_seq[-1].step

  return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value.contiguous()))


(900, Array(900, dtype=int32, weak_type=True))

In [25]:
from rl_env.render.renderer import SVG_Visualizer

# Render the states recorded by the rollout callback (environment 0 only) and save them
# as "maddpg_example.svg", a path relative to the current working directory.
SVG_Visualizer(viz_env._env, viz_env.state_seq).save_svg("maddpg_example.svg")

In [None]:
out["next", "agents", "observation"][1, :, :, 2:].max()

tensor(0.7285, device='cuda:0')

In [None]:
import jax
import jax.numpy as jnp


def rewards(self, agent_pos, landmark_pos, goal_dist):
	"""Per-agent collision indicator (debugging stand-in for the reward).

	For every agent the closest other object (another agent or a landmark) is looked up.
	The agent is flagged as colliding when that distance is below the sum of the two radii
	enlarged by 5%.

	Args:
		self: object providing `num_agents`, `agent_rad`, `landmark_rad` and `goal_rad`.
		agent_pos: agent positions, indexed as (agent, coordinate).
		landmark_pos: landmark positions, stacked below the agents with `jnp.vstack`.
		goal_dist: per-agent distance to the goal; it only feeds the unused `on_goal` mask.

	Returns:
		float32 array reshaped to (-1, 1): 1.0 where the agent collides, 0.0 otherwise.
	"""
	# Agents come first in the stacked array, so an index below `num_agents` is an agent.
	all_objects = jnp.vstack((agent_pos, landmark_pos))
	offsets = agent_pos[:, None, :] - all_objects[None, :, :]
	pairwise = jnp.linalg.norm(offsets, axis=-1)

	# Two smallest distances per agent (top_k on the negated matrix). Column 0 is dropped as
	# the agent's own zero distance; column 1 is the nearest other object.
	neg_two_closest, two_closest_ids = jax.lax.top_k(-pairwise, 2)
	neighbour_id = two_closest_ids[:, 1]
	neighbour_dist = -neg_two_closest[:, 1]

	# Contact distance depends on whether the neighbour is an agent or a landmark.
	agent_contact = 2 * self.agent_rad
	landmark_contact = self.agent_rad + self.landmark_rad
	contact_dist = jnp.where(neighbour_id < self.num_agents, agent_contact, landmark_contact)

	colliding = neighbour_dist < (contact_dist * 1.05)

	# Not part of the returned value; kept for goal-based reward variants such as
	# 1.0 * on_goal - 0.5 * collision.
	on_goal = goal_dist < self.goal_rad

	return colliding.astype(jnp.float32).reshape(-1, 1)

In [None]:
# Re-evaluate `rewards` on every recorded state of the visualisation rollout.
rewards_seq = []
for state_ in viz_env.state_seq:
	offset_to_goal = state_.agent_pos - state_.goal_pos
	goal_dist = jnp.linalg.norm(offset_to_goal, axis=-1)
	step_reward = rewards(viz_env._env, state_.agent_pos, state_.landmark_pos, goal_dist)
	rewards_seq.append(step_reward)

In [31]:
jnp.stack(rewards_seq).nonzero()

(Array([58, 58, 59, 59, 60, 60, 61, 61, 62, 62, 63, 63, 64, 64, 65, 65, 66,
        66, 67, 67, 68, 68, 69, 69, 70, 70, 71, 71, 72, 72, 73, 73, 74, 74,
        75, 75, 76, 76, 77, 77, 78, 78, 79, 79, 80, 80], dtype=int32),
 Array([0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4,
        0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4,
        0, 4], dtype=int32),
 Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0], dtype=int32))

In [33]:
rewards_seq[80]

Array([[1.],
       [0.],
       [0.],
       [0.],
       [1.]], dtype=float32)

In [47]:
len(viz_env.state_seq)

156

In [43]:
out["next", "agents", "observation"].shape

torch.Size([2, 157, 5, 18])

In [44]:
out["next", "agents", "observation"][0, 27, 0, :] * env._env.window

tensor([ 0.4722, -0.6458, -0.5506, -0.0432, -0.0594,  0.1023,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000], device='cuda:0')

In [29]:
jnp.stack(rewards_seq).nonzero()

(Array([146, 146, 147, 147, 148, 148, 149, 149, 150, 150, 151, 151, 152,
        152, 153, 153, 154, 154, 155, 155, 156, 156, 157, 157, 158, 158,
        159, 159, 160, 160, 161, 161, 162, 162, 163, 163, 164, 164, 165,
        165, 166, 166, 167, 167, 168, 168, 169, 169, 170, 170, 171, 171],      dtype=int32),
 Array([1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4,
        1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4,
        1, 4, 1, 4, 1, 4, 1, 4], dtype=int32),
 Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

In [31]:
# NOTE(review): this reads `state_seq` from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env` — presumably left over from an earlier session; confirm.
env.state_seq[146]

State(agent_pos=Array([[-0.24154969, -1.4732722 ],
       [ 0.29230696, -0.15225264],
       [-0.6070331 , -0.18168508],
       [-1.6136379 ,  1.6489443 ],
       [ 0.35022047, -0.39659575]], dtype=float32), agent_vel=Array([[ 1.16601214e-02,  3.88432704e-02],
       [ 2.21877787e-02, -1.51411280e-01],
       [ 1.05127799e-06,  1.34114870e-08],
       [-7.12279528e-02, -3.31292972e-02],
       [-1.03568904e-01,  9.79412422e-02]], dtype=float32), goal_pos=Array([[-0.5999999 , -1.8       ],
       [ 1.        , -1.8       ],
       [-0.5999999 , -0.19999999],
       [-1.4       ,  1.8       ],
       [-1.        ,  1.4000001 ]], dtype=float32), landmark_pos=Array([[-2.        ,  2.        ],
       [-1.9200001 ,  2.        ],
       [-1.8399999 ,  2.        ],
       [-1.76      ,  2.        ],
       [-1.6800001 ,  2.        ],
       [-1.5999999 ,  2.        ],
       [-1.52      ,  2.        ],
       [-1.44      ,  2.        ],
       [-1.3600001 ,  2.        ],
       [-1.28      , 

In [30]:
rewards_seq[146]

Array([[0.],
       [1.],
       [0.],
       [0.],
       [1.]], dtype=float32)

In [32]:
# Recompute by hand the agent-to-object distances for recorded state 146.
# NOTE(review): `state_seq` is read from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env` — presumably left over from an earlier session; confirm.
state_ = env.state_seq[146]
# Agents first, landmarks after them.
objects = jnp.vstack((state_.agent_pos, state_.landmark_pos))

# distances[i, j]: Euclidean distance from agent i to object j.
distances = jnp.linalg.norm(state_.agent_pos[:, None, :] - objects[None, :, :], axis=-1)

In [33]:
# Two largest entries per row of the negated distances, i.e. the two closest objects per agent.
# Depends on `distances` left in the kernel by the previous cell.
nearest_dists, nearest_ids = jax.lax.top_k(-distances, 2) # (num_agents, 2)

In [34]:
nearest_ids[:, 1]

Array([183,   4,   1,  10,   1], dtype=int32)

In [106]:
0.01 / env._env.agent_rad

0.078125

In [107]:
# Distance to the nearest other object minus the agent-landmark contact threshold (5% margin).
# NOTE(review): the agent+landmark radius sum is applied to every agent here, whereas `rewards`
# switches to 2 * agent_rad when the nearest object is another agent.
- nearest_dists[:, 1] - (env._env.agent_rad + env._env.landmark_rad) * 1.05

Array([-0.01004225, -0.01014116, -0.01172379, -0.01026562,  0.7889613 ],      dtype=float32)

In [None]:
# Offset from agent 1 to agent 4 at recorded state 123, divided by `env._env.window`.
# NOTE(review): `state_seq` is read from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env` — presumably left over from an earlier session; confirm.
(env.state_seq[123].agent_pos[4, :] - env.state_seq[123].agent_pos[1, :]) / env._env.window

Array([ 0.37683833, -0.9155774 ], dtype=float32)

In [69]:
out["next", "agents", "observation"][0, 0, 3, :]

tensor([-0.3944,  0.0025, -0.0430,  0.2532,  0.0570,  0.2532, -0.1430,  0.2532,
         0.1570,  0.2532, -0.2430,  0.2532,  0.2570,  0.2532, -0.3430,  0.2532,
         0.3570,  0.2532], device='cuda:0')

In [66]:
# NOTE(review): this reads `state_seq` from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env`; index 381 also requires a rollout of at least 382
# recorded steps — presumably left over from an earlier session; confirm.
env.state_seq[381].agent_pos

Array([[ 1.8263538 , -0.54345036],
       [ 1.4484562 ,  1.0495648 ],
       [ 0.6185359 , -1.3465694 ],
       [ 0.6263411 , -0.8824417 ],
       [-0.58231133,  0.66071725]], dtype=float32)

In [67]:
# NOTE(review): this reads `state_seq` from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env`; index 381 also requires a rollout of at least 382
# recorded steps — presumably left over from an earlier session; confirm.
env.state_seq[381].landmark_pos

Array([[-2.        ,  2.        ],
       [-1.9200001 ,  2.        ],
       [-1.8399999 ,  2.        ],
       [-1.76      ,  2.        ],
       [-1.6800001 ,  2.        ],
       [-1.5999999 ,  2.        ],
       [-1.52      ,  2.        ],
       [-1.44      ,  2.        ],
       [-1.3600001 ,  2.        ],
       [-1.28      ,  2.        ],
       [-1.2       ,  2.        ],
       [-1.1199999 ,  2.        ],
       [-1.04      ,  2.        ],
       [-0.96000004,  2.        ],
       [-0.88000005,  2.        ],
       [-0.80000013,  2.        ],
       [-0.72      ,  2.        ],
       [-0.6399999 ,  2.        ],
       [-0.56      ,  2.        ],
       [-0.48000002,  2.        ],
       [-0.4000001 ,  2.        ],
       [-0.3200001 ,  2.        ],
       [-0.24000001,  2.        ],
       [-0.16000009,  2.        ],
       [-0.07999998,  2.        ],
       [ 0.        ,  2.        ],
       [ 0.07999992,  2.        ],
       [ 0.15999985,  2.        ],
       [ 0.24000001,

In [27]:
agent_id = 0
out["next", "agents", "observation"][0, 0, agent_id, :]

tensor([-1.9973,  2.7944,  0.7533,  0.0430,  0.7533, -0.0570,  0.7533,  0.1430,
         0.7533, -0.1570,  0.7533,  0.2430,  0.7533, -0.2570,  0.7533,  0.3430,
         0.7533, -0.3570], device='cuda:0')

In [28]:
# NOTE(review): this reads `state_seq` from `env`, but the only visible cell that creates a
# `state_seq` attaches it to `viz_env` — presumably left over from an earlier session; confirm.
env.state_seq[0].agent_pos

Array([[ 1.3973355 , -0.99443763],
       [-1.7930828 , -0.1948089 ],
       [-1.4070048 ,  1.7959106 ],
       [-0.1949299 , -0.5935761 ],
       [ 0.5929402 , -0.60317606]], dtype=float32)

In [29]:
out["next", "agents", "episode_reward"][0, :, 2, 0]

tensor([-3.9302e-03, -7.7401e-03, -1.1426e-02, -1.4990e-02, -1.8433e-02,
        -2.1754e-02, -2.4954e-02, -2.8036e-02, -3.1001e-02, -3.3852e-02,
        -3.6590e-02, -3.9212e-02, -4.1716e-02, -4.4100e-02, -4.6370e-02,
        -4.8530e-02, -5.0574e-02, -5.2499e-02, -5.4303e-02, -5.5999e-02,
        -5.7591e-02, -5.9075e-02, -6.0465e-02, -6.1768e-02, -6.2964e-02,
        -6.4053e-02, -6.5035e-02, -6.5901e-02, -6.6642e-02, -6.7256e-02,
        -6.7751e-02,  9.3186e-01,  1.9316e+00,  2.9313e+00,  3.9311e+00,
         4.9310e+00,  5.9309e+00,  6.9308e+00,  7.9308e+00,  8.9306e+00,
         9.9304e+00,  1.0930e+01,  1.1930e+01,  1.2930e+01,  1.3930e+01,
         1.4929e+01,  1.5929e+01,  1.6929e+01,  1.7929e+01,  1.8928e+01,
         1.9928e+01,  2.0928e+01,  2.1928e+01,  2.2928e+01,  2.3927e+01,
         2.4927e+01,  2.5927e+01,  2.6927e+01,  2.7927e+01,  2.8927e+01,
         2.9927e+01,  3.0926e+01,  3.1926e+01,  3.2926e+01,  3.3926e+01,
         3.4926e+01,  3.5926e+01,  3.6926e+01,  3.7

In [30]:
out["next", "agents", "reward"].shape

torch.Size([120, 900, 5, 1])

In [31]:
out["next", "agents", "observation"][0, 122, 0, :]

tensor([-1.5092,  1.5685,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000], device='cuda:0')

In [32]:
out["next", "agents", "reward"][0, :, :, 0].max(axis=0)

torch.return_types.max(
values=tensor([-0.0021, -0.0084,  0.9999,  0.9998,  0.9998], device='cuda:0'),
indices=tensor([310, 744, 364, 124, 146], device='cuda:0'))

In [33]:
out["agents", "action"][0][-3][0]

tensor([-0.1659, -0.6391], device='cuda:0')

In [34]:
out["next", "agents", "observation"][0][-3]

tensor([[ 0.2785,  0.1100, -0.0019,  0.3875,  0.0981,  0.3875, -0.1019,  0.3875,
          0.1981,  0.3875, -0.2019,  0.3875,  0.2981,  0.3875, -0.3019,  0.3875,
          0.3981,  0.3875],
        [ 0.7557,  0.3890, -0.0054,  0.7363,  0.0946,  0.7363, -0.1054,  0.7363,
          0.1946,  0.7363, -0.2054,  0.7363,  0.2946,  0.7363, -0.3054,  0.7363,
          0.3946,  0.7363],
        [ 0.0100,  0.0054, -0.2375, -0.0432, -0.2375,  0.0568, -0.0375,  0.2568,
          0.0625,  0.2568, -0.2375, -0.1432, -0.2375,  0.1568, -0.1375,  0.2568,
          0.1625,  0.2568],
        [ 0.0115, -0.0155,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0115, -0.0155,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000]], device='cuda:0')

In [None]:
# Cross-check: rebuild the observation of the last recorded rollout state directly from a
# fresh environment, to compare against the observation returned by the rollout.
# `Env` replaces `Grid_Maze`, which no visible cell imports or defines (the call raised
# NameError on a fresh kernel); `Env` is the class built with this keyword set elsewhere here.
# NOTE(review): assumes `Env` exposes `get_obs(agent_pos, landmark_pos, goal_pos)` — confirm.
test_env = Env(
	width=width,
	height=height,
	obstacle_density=obstacle_density,
	num_agents=num_agents,
	grain_factor=grain_factor,
	contact_force=500,
	contact_margin=1e-3,
	dt=0.01,
	max_steps=300,
	frameskip=7,
	max_obs=env._env.max_obs,
)

# The recorded states live on `viz_env` (filled by `rendering_callback`), not on `env`.
last_state = viz_env.state_seq[-1]
obs = test_env.get_obs(last_state.agent_pos, last_state.landmark_pos, last_state.goal_pos)
obs.shape

(5, 18)

In [36]:
obs

Array([[ 0.2809835 ,  0.11932576,  0.0012292 ,  0.39915726, -0.09877078,
         0.39915726,  0.1012291 ,  0.39915726, -0.19877069,  0.39915726,
         0.20122923,  0.39915726, -0.2987706 ,  0.39915726,  0.30122936,
         0.39915726, -0.3987708 ,  0.39915726],
       [ 0.7566569 ,  0.38993514, -0.00417918,  0.737419  ,  0.09582102,
         0.737419  , -0.10417908,  0.737419  ,  0.19582093,  0.737419  ,
        -0.20417899,  0.737419  ,  0.29582083,  0.737419  , -0.3041792 ,
         0.737419  ,  0.39582103,  0.737419  ],
       [ 0.00546694, -0.00443459, -0.24316639,  0.04445672, -0.04316628,
         0.24445683, -0.24316639, -0.05554318,  0.05683362,  0.24445683,
        -0.24316639,  0.14445662, -0.14316648,  0.24445683, -0.24316639,
        -0.15554339,  0.15683353,  0.24445683],
       [ 0.01151991, -0.01554318,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.   

In [37]:
(obs == - 100).sum()

Array(0, dtype=int32)

In [38]:
out[0, -1].get(("agents", "observation"))

tensor([[ 0.2810,  0.1193,  0.0012,  0.3992, -0.0988,  0.3992,  0.1012,  0.3992,
         -0.1988,  0.3992,  0.2012,  0.3992, -0.2988,  0.3992,  0.3012,  0.3992,
         -0.3988,  0.3992],
        [ 0.7567,  0.3899, -0.0042,  0.7374,  0.0958,  0.7374, -0.1042,  0.7374,
          0.1958,  0.7374, -0.2042,  0.7374,  0.2958,  0.7374, -0.3042,  0.7374,
          0.3958,  0.7374],
        [ 0.0055, -0.0044, -0.2432,  0.0445, -0.0432,  0.2445, -0.2432, -0.0555,
          0.0568,  0.2445, -0.2432,  0.1445, -0.1432,  0.2445, -0.2432, -0.1555,
          0.1568,  0.2445],
        [ 0.0115, -0.0155,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0115, -0.0155,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000]], device='cuda:0')

In [39]:
(out[0, -1].get(("agents", "observation")) == -100).sum()

tensor(0, device='cuda:0')