In [5]:
import time

import numpy as np
from plangym import AtariEnvironment, ParallelEnvironment

from fragile.atari.env import AtariEnv
from fragile.core.dt_sampler import GaussianDt
from fragile.core.env import DiscreteEnv
from fragile.core.models import RandomDiscrete
from fragile.core.states import States
from fragile.core.swarm import Swarm
from fragile.core.tree import HistoryTree
from fragile.core.walkers import Walkers


# --- Tunable settings for this experiment -----------------------------------
GAME_NAME = "MsPacman-ram-v0"
N_WALKERS = 64
MAX_ITERS = 10000
REWARD_SCALE = 2
PRINT_EVERY = 50

# Wrap the Atari game in plangym's ParallelEnvironment. The same `env` object
# is reused by later cells for replaying and batch stepping.
env = ParallelEnvironment(
    env_class=AtariEnvironment,
    name=GAME_NAME,
    clone_seeds=True,
    autoreset=True,
    blocking=False,
)

# Sampler passed to the random discrete model as its `dt_sampler`.
# NOTE(review): presumably `dt` is the number of times an action is repeated,
# drawn from a Gaussian clipped to [min_dt, max_dt] — confirm against fragile.
dt = GaussianDt(min_dt=3, max_dt=1000, loc_dt=4, scale_dt=2)


def _make_model(x):
    """Build the random discrete model for `x`, sharing the `dt` sampler."""
    return RandomDiscrete(x, dt_sampler=dt)


def _make_env():
    """Build the fragile Atari wrapper around the shared plangym `env`."""
    return AtariEnv(env)


# Swarm receives factories (callables) for the model and the environment, and
# the classes to use for the walkers and the history tree.
swarm = Swarm(
    model=_make_model,
    walkers=Walkers,
    env=_make_env,
    n_walkers=N_WALKERS,
    max_iters=MAX_ITERS,
    prune_tree=True,
    reward_scale=REWARD_SCALE,
    minimize=False,  # rewards are maximized
    tree=HistoryTree,
    use_tree=True,
)

# Run the search, printing a progress report every PRINT_EVERY iterations.
# The return value is bound to `_` so it is not echoed as the cell output.
_ = swarm.run_swarm(print_every=PRINT_EVERY)


Best reward found: 11811.0000 , efficiency 0.663, Critic: None
Walkers iteration 2101 Best reward: 11811.00 Dead walkers: 15.62% Cloned: 23.44%

Walkers States: 
id_walkers shape (64,) Mean: 248511217039606400.000, Std: 5155967118006030336.000, Max: 8759404405908455424.000 Min: -8939735341124625408.000
compas_dist shape (64,) Mean: 31.500, Std: 18.473, Max: 63.000 Min: 0.000
compas_clone shape (64,) Mean: 33.484, Std: 17.918, Max: 63.000 Min: 0.000
processed_rewards shape (64,) Mean: 1.132, Std: 0.465, Max: 2.345 Min: 0.112
virtual_rewards shape (64,) Mean: 1.067, Std: 0.038, Max: 1.235 Min: 1.000
cum_rewards shape (64,) Mean: 11780.531, Std: 9.426, Max: 11811.000 Min: 11751.000
distances shape (64,) Mean: 1.062, Std: 0.558, Max: 1.978 Min: 0.057
clone_probs shape (64,) Mean: 0.087, Std: 0.103, Max: 0.393 Min: 0.000
will_clone shape (64,) Mean: 0.234, Std: 0.424, Max: 1.000 Min: 0.000
alive_mask shape (64,) Mean: 0.844, Std: 0.363, Max: 1.000 Min: 0.000
end_condition shape (64,) Mean:

In [6]:
# Replay the highest-scoring trajectory found by the swarm, one step at a time.
# `time` is imported with the other imports at the top of the notebook.
FRAME_DELAY_S = 0.05  # pause between rendered frames, in seconds

# Pick the walker with the largest cumulative reward and fetch its id, which
# is used to look up its branch in the history tree.
best_ix = swarm.walkers.states.cum_rewards.argmax()
best = swarm.walkers.states.id_walkers[best_ix]
path = swarm.tree.get_branch(best, from_hash=True)

# As used below, path[0] holds the states and path[1] the actions.
# NOTE(review): the first state is skipped — presumably it is the root state
# that precedes the first recorded action; confirm against HistoryTree.
for replay_state, replay_action in zip(path[0][1:], path[1]):
    env.step(state=replay_state, action=replay_action)
    env.render()
    time.sleep(FRAME_DELAY_S)

In [None]:
# Batched stepping demo: step several copies of the reset state, each with its
# own randomly sampled action, in a single `step_batch` call.
BATCH_SIZE = 10

state, obs = env.reset()

# Independent copies of the initial state, one per batch entry.
states = [state.copy() for _ in range(BATCH_SIZE)]
# One random action per batch entry, drawn from the environment's action space.
actions = [env.action_space.sample() for _ in range(BATCH_SIZE)]

data = env.step_batch(states=states, actions=actions)

# `step_batch` returns five parallel collections, unpacked here by position.
(
    new_states,
    observs,
    rewards,
    ends,
    infos,
) = data