In [1]:
import jax
from jax import numpy as jnp
from typing import Tuple
import chex

key = jax.random.key(0)

In [2]:
from jumanji_env.environments.complex_orchard.constants import (
    TREE_DISTANCE_ROW,
    TREE_DISTANCE_COL,
    TREE_VARIATION,
    TREE_DIAMETER,
    ORCHARD_FERTILITY,
    ROBOT_DIAMETER,
    BASKET_DIAMETER,
    APPLE_DIAMETER,
    APPLE_DENSITY,
)

from jumanji_env.environments.complex_orchard.generator import ComplexOrchardGenerator
gen = ComplexOrchardGenerator(width=2000, height=1600, num_picker_bots=4)

In [3]:
state = gen.sample_orchard(key)
state

ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.    ,  213.1386],
       [ 800.    ,  213.1386],
       [1200.    ,  213.1386],
       [1600.    ,  213.1386]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,

In [4]:
from jumanji_env.environments.complex_orchard.observer import BasicObserver

observer = BasicObserver(4, 2000, 1600)

observer.state_to_observation(state)

  from .autonotebook import tqdm as notebook_tqdm


ValueError: Incompatible shapes for broadcasting: (4, 2) and requested shape (4, 1)

In [5]:
from jumanji_env.environments.complex_orchard.utils import bots_possible_moves

new_positions = bots_possible_moves(state)
new_positions

Array([[[4.030000e+02, 2.131386e+02, 1.000000e+00],
        [3.970000e+02, 2.131386e+02, 1.000000e+00]],

       [[8.030000e+02, 2.131386e+02, 1.000000e+00],
        [7.970000e+02, 2.131386e+02, 1.000000e+00]],

       [[1.203000e+03, 2.131386e+02, 1.000000e+00],
        [1.197000e+03, 2.131386e+02, 1.000000e+00]],

       [[1.603000e+03, 2.131386e+02, 1.000000e+00],
        [1.597000e+03, 2.131386e+02, 1.000000e+00]]], dtype=float32)

In [6]:
len(state.bots.id[state.bots.id == 5])

0

In [7]:
jnp.array(-1)

Array(-1, dtype=int32, weak_type=True)

In [8]:
jnp.repeat(-1, 5)

Array([-1, -1, -1, -1, -1], dtype=int32, weak_type=True)

In [12]:
from jumanji_env.environments.complex_orchard.utils import distances_between_entities


jnp.argmin(distances_between_entities(state.bots.position, state.bots.position), axis=1)

Array([0, 1, 2, 3], dtype=int32)

In [13]:
nose_positions = state.bots.position + jnp.stack([jnp.cos(state.bots.orientation), jnp.sin(state.bots.orientation)], axis=1) * state.bots.diameter / 2

ValueError: Incompatible shapes for broadcasting: shapes=[(4, 2), (4,)]

In [17]:
jnp.stack([jnp.cos(state.bots.orientation) * state.bots.diameter, jnp.sin(state.bots.orientation) * state.bots.diameter], axis=1) / 2


Array([[30.,  0.],
       [30.,  0.],
       [30.,  0.],
       [30.,  0.]], dtype=float32)

In [18]:
2 in jnp.arange(4)

True

In [20]:
state.bots.id in jnp.array([2, 4, 8])

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [24]:
jax.vmap(lambda id, targets: jnp.any(id == targets), in_axes=(0, None))(state.bots.id, jnp.array([0, 2, 8]))

Array([ True, False,  True, False], dtype=bool)

In [26]:

state.bots.id.at[jax.vmap(lambda id, targets: jnp.any(id == targets), in_axes=(0, None))(state.bots.id, jnp.array([0, 2, 8]))].set(6)


Array([6, 1, 6, 3], dtype=int32)