In [1]:
import jax
from jax import numpy as jnp
from typing import Tuple
import chex

key = jax.random.key(0)

In [2]:
from jumanji_env.environments.complex_orchard.constants import (
    TREE_DISTANCE_ROW,
    TREE_DISTANCE_COL,
    TREE_VARIATION,
    TREE_DIAMETER,
    ORCHARD_FERTILITY,
    ROBOT_DIAMETER,
    BASKET_DIAMETER,
    APPLE_DIAMETER,
    APPLE_DENSITY,
)

from jumanji_env.environments.complex_orchard.generator import ComplexOrchardGenerator
gen = ComplexOrchardGenerator(width=2000, height=1600, num_picker_bots=4)

In [3]:
state = gen.sample_orchard(key)
state

ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  200.],
       [ 800.,  200.],
       [1200.,  200.],
       [1600.,  200.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32), position=Array([[ 164.54881,  226.81322],
       [ 158.36102,  531.16876],
       [ 164.56766,  851.78033],
       [ 180.78163, 1191.3467 ],
       [ 174.2184 , 1496.9141 ],
       [ 495.22784,  219.73045],
       [ 488.48328,  527.3298 ],
       [ 476.9649 ,  863.4358 ],
       [ 493.47452, 1190.8564 ],
       [ 495.1007 , 1499.6929 ],
       [ 821.54376,  228.9994 ],
       [ 837.8113 ,  544.5668 ]

In [4]:
from jumanji_env.environments.complex_orchard.observer import BasicObserver

observer = BasicObserver(4, 2000, 1600, 4)

observer.state_to_observation(state)

  from .autonotebook import tqdm as notebook_tqdm


ComplexOrchardObservation(agents_view=Array([[ 12.440796  ,  -6.9968567 ,  -0.5123229 ],
       [-16.079834  , -24.022705  ,  -2.1606612 ],
       [ -1.064331  ,   0.16418457,   2.988538  ],
       [  1.201294  ,  10.579575  ,   1.4577324 ]], dtype=float32), action_mask=Array([[ True,  True,  True,  True,  True,  True,  True],
       [ True, False, False,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True]], dtype=bool), step_count=Array(0, dtype=int32))

In [5]:
from jumanji_env.environments.complex_orchard.utils import bots_possible_moves

new_positions = bots_possible_moves(state)
new_positions

(Array([[[4.030e+02, 2.000e+02, 1.000e+00],
         [3.970e+02, 2.000e+02, 1.000e+00]],
 
        [[8.030e+02, 2.000e+02, 0.000e+00],
         [7.970e+02, 2.000e+02, 0.000e+00]],
 
        [[1.203e+03, 2.000e+02, 1.000e+00],
         [1.197e+03, 2.000e+02, 1.000e+00]],
 
        [[1.603e+03, 2.000e+02, 1.000e+00],
         [1.597e+03, 2.000e+02, 1.000e+00]]], dtype=float32),
 Array([False, False, False, False], dtype=bool),
 Array([False,  True, False, False], dtype=bool))

In [6]:
from jumanji_env.environments.complex_orchard.env import ComplexOrchard

env = ComplexOrchard(generator=gen)

state, timestep = env.reset(key)

In [7]:
from jumanji_env.environments.complex_orchard.constants import (
    NOOP,
    FORWARD,
    BACKWARD,
    LEFT,
    RIGHT,
)

action = jnp.repeat(NOOP, 4)
env.step(state, action)

(ComplexOrchardState(bots=ComplexOrchardBot(id=Array([0, 1, 2, 3], dtype=int32), position=Array([[ 400.,  200.],
        [ 800.,  200.],
        [1200.,  200.],
        [1600.,  200.]], dtype=float32), diameter=Array([60, 60, 60, 60], dtype=int32, weak_type=True), holding=Array([-1, -1, -1, -1], dtype=int32, weak_type=True), job=Array([0., 0., 0., 0.], dtype=float32), orientation=Array([0., 0., 0., 0.], dtype=float32)), trees=ComplexOrchardTree(id=Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32), position=Array([[ 164.54881,  226.81322],
        [ 158.36102,  531.16876],
        [ 164.56766,  851.78033],
        [ 180.78163, 1191.3467 ],
        [ 174.2184 , 1496.9141 ],
        [ 495.22784,  219.73045],
        [ 488.48328,  527.3298 ],
        [ 476.9649 ,  863.4358 ],
        [ 493.47452, 1190.8564 ],
        [ 495.1007 , 1499.6929 ],
        [ 821.54376,  228.9994 ],
        [ 837.81

In [44]:
# Now try stacking multiple environments together

keys = jax.random.split(key, 5)
env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
    jnp.stack(keys),
)

In [45]:
# Test out having different environments getting different actions
different_actions = [NOOP, FORWARD, BACKWARD, LEFT, RIGHT]
actions = jnp.array([[a for _ in range(4)] for a in different_actions])

new_env_states, new_timesteps = jax.vmap(env.step, in_axes=(0, 0))(env_states, actions)

In [10]:
# This takes a LONG time to run!

from jumanji_env.environments.complex_orchard.astar_solver import AStarSolver
from jumanji_env.environments.complex_orchard.constants import TICK_SPEED
import time

eval_gen = ComplexOrchardGenerator(width=500, height=500, num_picker_bots=2)

NUM_EVALS = 10
avg_collected = 0
start_time = time.time()

for i in range(NUM_EVALS):
    seed = i + 10
    solver = AStarSolver(eval_gen, seed)
    percent_collected = solver.simulate(1500)
    
    print(f'Seed {seed} collected {percent_collected:.2f}%')
    avg_collected += percent_collected / NUM_EVALS

wall_time = time.time() - start_time
print(f'A Star collected {avg_collected:.2f}% ({wall_time} wall time) in 2.5 simulated minutes')

Seed 0 collected 51.43%
Seed 1 collected 54.29%
Seed 2 collected 48.57%
Seed 3 collected 51.43%
Seed 4 collected 51.43%
Seed 5 collected 54.29%
Seed 6 collected 54.29%
Seed 7 collected 54.29%
Seed 8 collected 54.29%
Seed 9 collected 11.43%
A Star collected 48.57% (2497.0341551303864 wall time) in 2.5 simulated minutes


In [11]:
import json
from simulation.orchard import OrchardComplex2D

stored = '{"width": 500, "height": 500, "seed": 0, "time": 4088, "bots": [{"x": 210.5732879638672, "y": 177.29385375976562, "diameter": 60, "holding": 12, "job": "picker", "orientation": -1.0471975803375244}, {"x": 299.6589050292969, "y": 197.19842529296875, "diameter": 60, "holding": 9, "job": "picker", "orientation": -2.094395160675049}], "trees": [{"x": 219.70751953125, "y": 334.7588195800781, "diameter": 40.51554870605469}], "baskets": [{"x": 250.0, "y": 100.0, "diameter": 100, "held": false, "collected": false}], "apples": [{"x": 292.05035400390625, "y": 138.4942169189453, "diameter": 5.224129676818848, "held": false, "collected": true}, {"x": 282.29742431640625, "y": 146.67076110839844, "diameter": 5.918465614318848, "held": false, "collected": true}, {"x": 277.4336853027344, "y": 150.314697265625, "diameter": 4.143068313598633, "held": false, "collected": true}, {"x": 276.6600036621094, "y": 150.88446044921875, "diameter": 6.100149154663086, "held": false, "collected": true}, {"x": 277.0805358886719, "y": 149.37477111816406, "diameter": 4.908559799194336, "held": false, "collected": true}, {"x": 266.8163757324219, "y": 156.47422790527344, "diameter": 6.466621398925781, "held": false, "collected": true}, {"x": 277.7527160644531, "y": 150.7341766357422, "diameter": 4.312686920166016, "held": false, "collected": true}, {"x": 266.75421142578125, "y": 155.98068237304688, "diameter": 5.1193647384643555, "held": false, "collected": true}, {"x": 269.9568176269531, "y": 152.66725158691406, "diameter": 5.868011474609375, "held": false, "collected": true}, {"x": 287.4568176269531, "y": 169.7920684814453, "diameter": 5.809104919433594, "held": true, "collected": false}, {"x": 296.4330749511719, "y": 136.27197265625, "diameter": 3.4358057975769043, "held": false, "collected": true}, {"x": 271.32916259765625, "y": 152.822265625, "diameter": 5.647130489349365, "held": false, "collected": true}, {"x": 224.0732879638672, "y": 153.9111785888672, "diameter": 4.892746925354004, "held": true, "collected": false}, {"x": 224.9587860107422, "y": 152.01971435546875, "diameter": 5.447516918182373, "held": false, "collected": true}, {"x": 266.99237060546875, "y": 155.0208282470703, "diameter": 5.300501823425293, "held": false, "collected": true}, {"x": 222.32864379882812, "y": 150.56060791015625, "diameter": 5.968918800354004, "held": false, "collected": true}, {"x": 272.2357177734375, "y": 154.35313415527344, "diameter": 6.829443454742432, "held": false, "collected": true}, {"x": 222.12884521484375, "y": 151.05352783203125, "diameter": 5.560314655303955, "held": false, "collected": true}, {"x": 224.972900390625, "y": 152.6625518798828, "diameter": 5.3172478675842285, "held": false, "collected": true}, {"x": 219.2754364013672, "y": 147.6256103515625, "diameter": 4.761844635009766, "held": false, "collected": true}, {"x": 224.59046936035156, "y": 152.1255645751953, "diameter": 4.875090599060059, "held": false, "collected": true}, {"x": 218.17922973632812, "y": 148.5032196044922, "diameter": 4.25541877746582, "held": false, "collected": true}, {"x": 215.25230407714844, "y": 145.820068359375, "diameter": 4.428750038146973, "held": false, "collected": true}, {"x": 220.0056915283203, "y": 149.2542724609375, "diameter": 5.080661773681641, "held": false, "collected": true}, {"x": 214.275146484375, "y": 146.63134765625, "diameter": 4.319875240325928, "held": false, "collected": true}, {"x": 217.0486602783203, "y": 147.61203002929688, "diameter": 5.780456066131592, "held": false, "collected": true}, {"x": 210.30020141601562, "y": 142.0749053955078, "diameter": 4.1222710609436035, "held": false, "collected": true}, {"x": 205.52023315429688, "y": 138.11996459960938, "diameter": 3.4948997497558594, "held": false, "collected": true}, {"x": 212.85874938964844, "y": 142.03475952148438, "diameter": 4.4725799560546875, "held": false, "collected": true}, {"x": 218.28762817382812, "y": 147.01902770996094, "diameter": 5.3883137702941895, "held": false, "collected": true}, {"x": 250.8977813720703, "y": 157.10797119140625, "diameter": 4.282293319702148, "held": false, "collected": true}, {"x": 263.97723388671875, "y": 155.98333740234375, "diameter": 5.383018970489502, "held": false, "collected": true}, {"x": 253.28762817382812, "y": 158.221435546875, "diameter": 2.981675148010254, "held": false, "collected": true}, {"x": 263.97723388671875, "y": 155.98333740234375, "diameter": 4.358050346374512, "held": false, "collected": true}, {"x": 260.3659362792969, "y": 157.31910705566406, "diameter": 4.419834613800049, "held": false, "collected": true}], "TICK_SPEED": 10}'
stored_dict = json.loads(stored)

print(len([
    i
    for i, apple in enumerate(stored_dict['apples'])
    if apple['collected']
]))
print(len(stored_dict['apples']))

old_env = OrchardComplex2D(stored_dict['width'], stored_dict['height'], 2, 0, 1, 0)
old_env.trees = stored_dict['trees']
old_env.baskets = stored_dict['baskets']
old_env.apples = stored_dict['apples']
old_env.bots = stored_dict['bots']

old_env.try_pick(0)

33
35


In [8]:
import json
import jax
from apple_mava.log_dicts import create_complex_dict

random_gen = ComplexOrchardGenerator(width=500, height=500, num_picker_bots=2)
random_env = ComplexOrchard(generator=random_gen)

random_state, timestep = random_env.reset(key)
stored_states = [create_complex_dict(random_state, 0)]
current_key = key

for i in range(1500):
    if i % 100 == 0:
        print(i)
        
    actions = jax.random.randint(current_key, (2,), 0, 6)
    random_state, timestep = random_env.step(random_state, actions)
    stored_states.append(create_complex_dict(random_state, 0))
    
    current_key = jax.random.split(current_key, 1)[0]
    
with open('random.json', 'w') as f:
    f.write(json.dumps(stored_states))

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
