In [2]:
import time

import numpy as np
from plangym import AtariEnvironment, ParallelEnvironment

from fragile.atari.env import AtariEnv
from fragile.core.dt_sampler import GaussianDt
from fragile.core.env import DiscreteEnv
from fragile.core.models import RandomDiscrete
from fragile.core.states import States
from fragile.core.swarm import Swarm
from fragile.core.tree import HistoryTree
from fragile.core.walkers import Walkers


# Shared game environment used by every later cell: a ParallelEnvironment
# built around AtariEnvironment for the "MsPacman-ram-v0" game.
# NOTE(review): the exact effect of clone_seeds / autoreset / blocking is
# defined by plangym and is not visible here - confirm against its docs.
env = ParallelEnvironment(
    env_class=AtariEnvironment,
    name="MsPacman-ram-v0",
    clone_seeds=True,
    autoreset=True,
    blocking=False,
)



In [8]:
# Sanity check: display the name the environment was created with.
env.name

'MsPacman-ram-v0'

In [None]:
# Sampler passed to the model as its ``dt_sampler``.
# NOTE(review): presumably draws how long each sampled action is applied,
# bounded by min_dt / max_dt - confirm against the fragile documentation.
dt = GaussianDt(min_dt=3, max_dt=1000, loc_dt=4, scale_dt=2)


def _build_model(x):
    """Return a RandomDiscrete model for ``x`` that uses the notebook-level ``dt`` sampler."""
    return RandomDiscrete(x, dt_sampler=dt)


def _build_env():
    """Return an AtariEnv wrapping the notebook-level ``env``."""
    return AtariEnv(env)


# Swarm receives factories/classes (not instances) for its model, walkers,
# environment and tree, plus the search hyper-parameters below.
swarm = Swarm(
    model=_build_model,
    walkers=Walkers,
    env=_build_env,
    n_walkers=64,
    max_iters=10000,
    prune_tree=True,
    reward_scale=2,
    minimize=False,
    tree=HistoryTree,
    use_tree=True,
)

# Run the search; the return value is bound to ``_`` so the cell shows no output.
_ = swarm.run_swarm(print_every=50)

In [6]:
# Replay the branch that ends at the walker with the highest cumulative reward.
# ``time`` is imported in the notebook's import cell rather than mid-cell, so
# every dependency is declared in one place.

# Index of the walker whose cumulative reward is largest.
best_ix = swarm.walkers.states.cum_rewards.argmax()
# Identifier of that walker, used to look its branch up in the tree.
best = swarm.walkers.states.id_walkers[best_ix]
path = swarm.tree.get_branch(best, from_hash=True)

# path[0] is iterated as states and path[1] as actions; the first entry of
# path[0] is skipped before pairing.
# NOTE(review): presumably path[0][0] is the root state that has no action
# leading to it - confirm against HistoryTree.get_branch.
for s, a in zip(path[0][1:], path[1]):
    env.step(state=s, action=a)
    env.render()
    # Short pause so the rendered frames are watchable.
    time.sleep(0.05)

In [None]:
# Demonstrate stepping a batch of states in one call.
# Reset once, then reuse copies of that single starting state.
state, obs = env.reset()

n_copies = 10  # size of the demo batch
states = [state.copy() for _ in range(n_copies)]
actions = [env.action_space.sample() for _ in range(n_copies)]

# step_batch returns a 5-item result, unpacked below into its components.
data = env.step_batch(states=states, actions=actions)
(new_states, observs, rewards, ends, infos) = data