In [1]:
import jax
from jax import numpy as jnp
from typing import Tuple
import chex

key = jax.random.key(0)

In [2]:
from jumanji_env.environments.complex_orchard.constants import (
    TREE_DISTANCE_ROW,
    TREE_DISTANCE_COL,
    TREE_VARIATION,
    TREE_DIAMETER,
    ORCHARD_FERTILITY,
    ROBOT_DIAMETER,
    BASKET_DIAMETER,
    APPLE_DIAMETER,
    APPLE_DENSITY,
)

from jumanji_env.environments.complex_orchard.generator import ComplexOrchardGenerator
gen = ComplexOrchardGenerator(width=2000, height=1600, num_picker_bots=4)

In [3]:
state = gen.sample_orchard(key)
state

ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  320.],
       [ 800.,  320.],
       [1200.,  320.],
       [1600.,  320.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44], dtype=int32), position=Array([[ 288.3587 ,  588.271  ],
       [ 306.3193 ,  784.2533 ],
       [ 301.2496 ,  905.0596 ],
       [ 280.61194, 1075.1935 ],
       [ 260.1008 , 1259.1196 ],
       [ 450.77747,  610.0664 ],
       [ 454.6857 ,  781.78217],
       [ 446.41635,  931.7907 ],
       [ 445.29117, 1077.7565 ],
       [ 463.55804, 1253.0273 

In [4]:
from jumanji_env.environments.complex_orchard.observer import BasicObserver

observer = BasicObserver(4, 2000, 1600, 4)

observer.state_to_observation(state)

  from .autonotebook import tqdm as notebook_tqdm


ComplexOrchardObservation(agents_view=Array([[-77.46457  ,  95.35062  ,   2.2530634],
       [ 18.707031 ,  96.351654 ,   1.3790286],
       [-83.630005 , 112.17703  ,   2.2114232],
       [-21.97522  , 152.72845  ,   1.7136996]], dtype=float32), action_mask=Array([[ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True]], dtype=bool), step_count=Array(0, dtype=int32))

In [5]:
from jumanji_env.environments.complex_orchard.utils import bots_possible_moves

new_positions = bots_possible_moves(state)
new_positions

Array([[[4.030e+02, 3.200e+02, 1.000e+00],
        [3.970e+02, 3.200e+02, 1.000e+00]],

       [[8.030e+02, 3.200e+02, 1.000e+00],
        [7.970e+02, 3.200e+02, 1.000e+00]],

       [[1.203e+03, 3.200e+02, 1.000e+00],
        [1.197e+03, 3.200e+02, 1.000e+00]],

       [[1.603e+03, 3.200e+02, 1.000e+00],
        [1.597e+03, 3.200e+02, 1.000e+00]]], dtype=float32)

In [6]:
from jumanji_env.environments.complex_orchard.env import ComplexOrchard

env = ComplexOrchard(generator=gen)

state, timestep = env.reset(key)

key passed to env.reset: Array((), dtype=key<fry>) overlaying:
[0 0], type: <class 'jax._src.prng.PRNGKeyArray'>


In [7]:
from jumanji_env.environments.complex_orchard.constants import (
    NOOP,
    FORWARD,
    BACKWARD,
    LEFT,
    RIGHT,
)

action = jnp.repeat(NOOP, 4)
env.step(state, action)

(ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  320.],
        [ 800.,  320.],
        [1200.,  320.],
        [1600.,  320.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44], dtype=int32), position=Array([[ 288.3587 ,  588.271  ],
        [ 306.3193 ,  784.2533 ],
        [ 301.2496 ,  905.0596 ],
        [ 280.61194, 1075.1935 ],
        [ 260.1008 , 1259.1196 ],
        [ 450.77747,  610.0664 ],
        [ 454.6857 ,  781.78217],
        [ 446.41635,  931.7907 ],
        [ 445.29117, 1077.7565 ],
        [ 463.55

In [8]:
# Now try stacking multiple environments together

keys = jax.random.split(key, 5)
env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
    jnp.stack(keys),
)

key passed to env.reset: Traced<ShapedArray(key<fry>[])>with<BatchTrace(level=1/0)> with
  val = Array((5,), dtype=key<fry>) overlaying:
[[1524306142 2595015335]
 [1836460763  990488084]
 [1416732029 1887795613]
 [1078027127 1191019179]
 [3297765038 3069809391]]
  batch_dim = 0, type: <class 'jax._src.interpreters.batching.BatchTracer'>


In [9]:
# Test out having different environments getting different actions
different_actions = [NOOP, FORWARD, BACKWARD, LEFT, RIGHT]
actions = jnp.array([[a for _ in range(4)] for a in different_actions])

new_env_states, new_timesteps = jax.vmap(env.step, in_axes=(0, 0))(env_states, actions)

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError