In [1]:
import jax
from jax import numpy as jnp
from typing import Tuple
import chex

key = jax.random.key(0)

In [2]:
from jumanji_env.environments.complex_orchard.constants import (
    TREE_DISTANCE_ROW,
    TREE_DISTANCE_COL,
    TREE_VARIATION,
    TREE_DIAMETER,
    ORCHARD_FERTILITY,
    ROBOT_DIAMETER,
    BASKET_DIAMETER,
    APPLE_DIAMETER,
    APPLE_DENSITY,
)

from jumanji_env.environments.complex_orchard.generator import ComplexOrchardGenerator
gen = ComplexOrchardGenerator(width=2000, height=1600)

In [3]:
state = gen.sample_orchard(key)
state

In [4]:
from jumanji_env.environments.complex_orchard.observer import BasicObserver

observer = BasicObserver(4, 2000, 1600)

observer.state_to_observation(state)

  from .autonotebook import tqdm as notebook_tqdm


ComplexOrchardObservation(agents_view=Array([[-59.30652  ,  42.889603 ,   2.515475 ],
       [ -1.3530273, -16.591583 ,  -1.6521653]], dtype=float32), action_mask=Array([[1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1.]], dtype=float32), time=Array(0, dtype=int32))

In [5]:
from jumanji_env.environments.complex_orchard.utils import bots_possible_moves

bots_possible_moves(state)

Array([[[6.6966669e+02, 2.1313860e+02, 1.0000000e+00],
        [6.6366669e+02, 2.1313860e+02, 1.0000000e+00]],

       [[1.3363334e+03, 2.1313860e+02, 1.0000000e+00],
        [1.3303334e+03, 2.1313860e+02, 1.0000000e+00]]], dtype=float32)

In [6]:
from jumanji_env.environments.complex_orchard.constants import ROBOT_MOVE_SPEED, JaxArray

bots_pos = jnp.concat([state.bots.position, state.bots.position])
bots_orien = jnp.concat([state.bots.orientation, state.bots.orientation])

num_bots = bots_pos.shape[0]
direction: JaxArray['num_bots', 2] = jnp.stack([jnp.cos(bots_orien), jnp.sin(bots_orien)], axis=1) * ROBOT_MOVE_SPEED

# Calculate the new position of all of the bots
new_forward_positions: JaxArray['num_bots', 2] = bots_pos + direction
new_backward_positions: JaxArray['num_bots', 2] = bots_pos - direction

# Make space for the is_possible value
new_forward_positions: JaxArray['num_bots', 3] = jnp.pad(new_forward_positions, (0, 1), constant_values=1)[:-1]
new_backward_positions: JaxArray['num_bots', 3] = jnp.pad(new_backward_positions, (0, 1), constant_values=1)[:-1]


In [7]:
new_positions = jnp.stack([new_forward_positions, new_backward_positions], axis=1)

In [8]:
new_positions

Array([[[6.6966669e+02, 2.1313860e+02, 1.0000000e+00],
        [6.6366669e+02, 2.1313860e+02, 1.0000000e+00]],

       [[1.3363334e+03, 2.1313860e+02, 1.0000000e+00],
        [1.3303334e+03, 2.1313860e+02, 1.0000000e+00]],

       [[6.6966669e+02, 2.1313860e+02, 1.0000000e+00],
        [6.6366669e+02, 2.1313860e+02, 1.0000000e+00]],

       [[1.3363334e+03, 2.1313860e+02, 1.0000000e+00],
        [1.3303334e+03, 2.1313860e+02, 1.0000000e+00]]], dtype=float32)

In [9]:
new_positions[:, :, 0].reshape((num_bots * 2,))

Array([ 669.6667,  663.6667, 1336.3334, 1330.3334,  669.6667,  663.6667,
       1336.3334, 1330.3334], dtype=float32)

In [10]:
direction

Array([[3., 0.],
       [3., 0.],
       [3., 0.],
       [3., 0.]], dtype=float32)

In [11]:
state.bots.position + direction

TypeError: add got incompatible shapes for broadcasting: (2, 2), (4, 2).

In [None]:
state.bots.position - direction

In [None]:
jnp.concat([state.bots.position, state.bots.position])

In [15]:
jnp.linalg.norm(state.bots.position - state.baskets.position, axis=1)

Array([333.3333 , 333.33337], dtype=float32)

In [16]:
state.bots.position.shape

(2, 2)

In [17]:
state.baskets.position.shape

(1, 2)

In [19]:
jax.vmap(lambda target, others: jnp.linalg.norm(target - others, axis=1), in_axes=(0, None))(state.bots.position, state.baskets.position)


Array([[333.3333 ],
       [333.33337]], dtype=float32)