In [1]:
import jax
from jax import numpy as jnp
from typing import Tuple
import chex

# Root PRNG key for the notebook; a fixed seed means reruns draw the same random values.
key = jax.random.key(seed=0)

In [2]:
from jumanji_env.environments.complex_orchard.constants import (
    TREE_DISTANCE_ROW,
    TREE_DISTANCE_COL,
    TREE_VARIATION,
    TREE_DIAMETER,
    ORCHARD_FERTILITY,
    ROBOT_DIAMETER,
    BASKET_DIAMETER,
    APPLE_DIAMETER,
    APPLE_DENSITY,
)

from jumanji_env.environments.complex_orchard.generator import ComplexOrchardGenerator
# Orchard generator configured with width 2000, height 1600 and four picker bots.
gen = ComplexOrchardGenerator(
    width=2000,
    height=1600,
    num_picker_bots=4,
)

In [3]:
# Sample one orchard state from the generator using the notebook's PRNG key.
state = gen.sample_orchard(key)
# Bare expression: show the sampled state as the cell output.
state

ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  200.],
       [ 800.,  200.],
       [1200.,  200.],
       [1600.,  200.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32), position=Array([[ 164.54881,  226.81322],
       [ 158.36102,  531.16876],
       [ 164.56766,  851.78033],
       [ 180.78163, 1191.3467 ],
       [ 174.2184 , 1496.9141 ],
       [ 495.22784,  219.73045],
       [ 488.48328,  527.3298 ],
       [ 476.9649 ,  863.4358 ],
       [ 493.47452, 1190.8564 ],
       [ 495.1007 , 1499.6929 ],
       [ 821.54376,  228.9994 ],
       [ 837.8113 ,  544.5668 ]

In [4]:
from jumanji_env.environments.complex_orchard.observer import BasicObserver

# NOTE(review): these constructor arguments are positional and their meaning is not
# visible here; 2000 and 1600 presumably mirror the generator's width/height and the
# 4s may relate to the bot count -- confirm against BasicObserver's signature.
observer = BasicObserver(4, 2000, 1600, 4)

# Convert the sampled state into an observation and show it as the cell output.
observer.state_to_observation(state)

  from .autonotebook import tqdm as notebook_tqdm


ComplexOrchardObservation(agents_view=Array([[ 12.440796  ,  -6.9968567 ,  -0.5123229 ],
       [-16.079834  , -24.022705  ,  -2.1606612 ],
       [ -1.064331  ,   0.16418457,   2.988538  ],
       [  1.201294  ,  10.579575  ,   1.4577324 ]], dtype=float32), action_mask=Array([[ True,  True,  True,  True,  True,  True,  True],
       [ True, False, False,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True]], dtype=bool), step_count=Array(0, dtype=int32))

In [5]:
from jumanji_env.environments.complex_orchard.utils import bots_possible_moves

# Ask the helper for each bot's possible moves in the sampled state.
# NOTE(review): the displayed result has shape (4, 2, 3); each row looks like
# (x, y, flag) for a +3 / -3 step along x, with the flag presumably marking whether
# the move is allowed -- confirm against bots_possible_moves.
new_positions = bots_possible_moves(state)
# Bare expression: show the array as the cell output.
new_positions

Array([[[4.030e+02, 2.000e+02, 1.000e+00],
        [3.970e+02, 2.000e+02, 1.000e+00]],

       [[8.030e+02, 2.000e+02, 0.000e+00],
        [7.970e+02, 2.000e+02, 0.000e+00]],

       [[1.203e+03, 2.000e+02, 1.000e+00],
        [1.197e+03, 2.000e+02, 1.000e+00]],

       [[1.603e+03, 2.000e+02, 1.000e+00],
        [1.597e+03, 2.000e+02, 1.000e+00]]], dtype=float32)

In [6]:
from jumanji_env.environments.complex_orchard.env import ComplexOrchard

# Build the environment around the generator configured earlier.
env = ComplexOrchard(generator=gen)

# Reset using the same `key` that was already passed to gen.sample_orchard.
# This rebinds `state`, replacing the value sampled directly from the generator.
state, timestep = env.reset(key)

In [7]:
from jumanji_env.environments.complex_orchard.constants import (
    NOOP,
    FORWARD,
    BACKWARD,
    LEFT,
    RIGHT,
)

# One action per bot. The bot count is read from the state (one id per bot)
# rather than hard-coded, so this cell keeps working if the generator's
# num_picker_bots is changed.
num_bots = state.bots.id.shape[0]
action = jnp.repeat(NOOP, num_bots)

# Step once with the NOOP action for every bot and show the returned tuple.
env.step(state, action)

(ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  200.],
        [ 800.,  200.],
        [1200.,  200.],
        [1600.,  200.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32), position=Array([[ 164.54881,  226.81322],
        [ 158.36102,  531.16876],
        [ 164.56766,  851.78033],
        [ 180.78163, 1191.3467 ],
        [ 174.2184 , 1496.9141 ],
        [ 495.22784,  219.73045],
        [ 488.48328,  527.3298 ],
        [ 476.9649 ,  863.4358 ],
        [ 493.47452, 1190.8564 ],
        [ 495.1007 , 1499.6929 ],
        [ 821.54376,  228.9994 ],
        [ 837.81

In [8]:
# Now try stacking multiple environments together

# jax.random.split returns a batched key array (one key per row), so it can be
# mapped over directly without re-stacking; vmap's default in_axes=0 maps env.reset
# over that leading axis, producing one environment per key.
keys = jax.random.split(key, 5)
env_states, timesteps = jax.vmap(env.reset)(keys)

In [9]:
# Test out having different environments getting different actions
different_actions = [NOOP, FORWARD, BACKWARD, LEFT, RIGHT]
# Row i repeats action i across the 4 bots, so each environment gets a single action type.
actions = jnp.repeat(jnp.array(different_actions)[:, None], 4, axis=1)

# Map env.step over the leading (environment) axis of both the states and the actions.
new_env_states, new_timesteps = jax.vmap(env.step, in_axes=0)(env_states, actions)