In [None]:
import warnings
# Installed before the imports below, so the filter is already active while they run.
warnings.filterwarnings("ignore")
from fragile.atari.swarm import MontezumaSwarm
import ray
# ignore_reinit_error=True keeps this cell re-runnable: calling ray.init() a second
# time in the same kernel would otherwise raise instead of reusing the running instance.
ray.init(ignore_reinit_error=True)

In [None]:
# Build the Montezuma swarm that the following cells plot and run.
swarm = MontezumaSwarm.create_swarm(
    n_walkers=75,
    max_iters=1000,
    prune_tree=True,
    use_tree=False,
    reward_scale=4,
    dist_scale=0.5,
    plot_step=2,
    critic_scale=2,
    # NOTE(review): spelled "episodic_rewad" (no second "r") -- confirm this matches
    # the parameter name create_swarm actually accepts before renaming it.
    episodic_rewad=True,
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
# Keep the object returned by plot_dmap() and show it as this cell's output.
dmap = swarm.plot_dmap()
dmap

In [None]:
# Run the swarm; binding the result to `_` keeps it from being echoed as cell output.
_ = swarm.run_swarm(print_every=1)

In [None]:
# Inspect the dtype of the walkers' environment states.
swarm.walkers.env_states.states.dtype

In [None]:
import holoviews as hv
# Select the matplotlib plotting backend for HoloViews.
hv.extension("matplotlib")
# Take item 0 of `dmap` and set normalize=True on it.
dmap[0].opts(normalize=True)

In [None]:
# Reshape the states to `swarm.walkers.n` rows and print every entry of row 0.
# NOTE(review): this prints one line per entry, which can flood the output area.
for x in swarm.walkers.env_states.states.reshape(swarm.walkers.n, -1)[0]:
    print(x)

In [None]:
import numpy as np
# Number of elements in an array of shape (210, 160, 3).
np.prod((210,160,3))

In [None]:
# Call the swarm's critic plotting helper.
swarm.plot_critic()

In [None]:
import time
import numpy as np
import pandas as pd
import holoviews as hv
import streamz
import streamz.dataframe

from holoviews import opts
from holoviews.streams import Pipe, Buffer

# Select the bokeh plotting backend for HoloViews.
hv.extension('bokeh')

In [None]:
# NOTE(review): `data` is not assigned by any earlier cell in this notebook (its
# first assignment is inside the streaming loop further down), so this cell fails
# when the notebook is run top to bottom on a fresh kernel.
hv.Image(data)

In [None]:
# Whatever is sent through `pipe` is handed to hv.Image by the DynamicMap.
pipe = Pipe(data=[])
vector_dmap = hv.DynamicMap(hv.Image, streams=[pipe])
vector_dmap.opts(xlim=(-0.5, 0.5), ylim=(-0.5, 0.5))

In [None]:
# Send rows 0-9 of the walkers' observations through `pipe`, one every 0.1 s.
# The last three entries of each row are dropped before the remainder is
# reshaped into a (210, 160, 3) array.
for i in range(10):
    time.sleep(0.1)
    data = swarm.walkers.env_states.observs[i, :-3].reshape((210, 160, 3))
    pipe.send(data)

In [None]:
import numpy as np
from plangym import ParallelEnvironment
from plangym.minimal_montezuma import Montezuma
#from plangym.montezuma import Montezuma

from fragile.core.env import DiscreteEnv
from fragile.core.dt_sampler import GaussianDt
from fragile.core.models import RandomDiscrete
from fragile.core.states import States
from fragile.core.swarm import Swarm
from fragile.core.walkers import Walkers
from fragile.core.tree import HistoryTree
from fragile.atari.walkers import MontezumaWalkers
from fragile.atari.critics import MontezumaGrid
# ParallelEnvironment over the Montezuma class; wrapped in DiscreteEnv by the swarm below.
env = ParallelEnvironment(
    env_class=Montezuma,
    name=None,
    autoreset=True,
    blocking=False,
    episodic_live=True,
    min_dt=1,
)
# dt sampler passed to RandomDiscrete by the model factory below.
dt = GaussianDt(min_dt=3, max_dt=1000, loc_dt=4, scale_dt=2)

# Rebinds `swarm`, replacing the instance created near the top of the notebook.
# `model` and `env` are passed as callables rather than as instances.
swarm = Swarm(
    model=lambda x: RandomDiscrete(x, dt_sampler=dt),
    walkers=MontezumaWalkers,
    env=lambda: DiscreteEnv(env),
    n_walkers=50,
    max_iters=25,
    prune_tree=True,
    reward_scale=2,
    dist_scale=4,
    tree=HistoryTree,
    critic=MontezumaGrid(),
)



In [None]:
from holoviews.streams import Pipe, Buffer
from streamz.dataframe import DataFrame
from streamz import Stream
import holoviews as hv
import hvplot.pandas
import hvplot.streamz
import numpy as np
import pandas as pd
# Select the bokeh plotting backend for HoloViews.
hv.extension("bokeh")

In [None]:
# Stand-alone Montezuma instance (class imported from plangym.minimal_montezuma).
mm = Montezuma()

In [None]:
# Reset the wrapped environment and keep what reset() returns.
o = mm.env.reset()

In [None]:
o

In [None]:
# Show channel 0 of `o` with the first 50 rows cropped away.
plt.imshow(o[50:, :, 0])

In [None]:
from PIL import Image
def resize_frame(frame: np.ndarray, height: int, width: int, mode="RGB") -> np.ndarray:
    """
    Use PIL to resize a frame to a specified height and width.

    Args:
        frame: Target numpy array representing the image that will be resized.
        height: Height of the resized image, in pixels.
        width: Width of the resized image, in pixels.
        mode: PIL mode the image is converted to before resizing (for example
            "RGB" or "L").

    Returns:
        The resized frame as a numpy array whose first two dimensions are
        (height, width).
    """
    image = Image.fromarray(frame).convert(mode)
    # PIL's Image.resize takes its size as (width, height); passing the values
    # in that order is what yields an array with `height` rows.
    resized = image.resize((width, height))
    return np.array(resized)

In [None]:
# Overlay the resized, max-normalised `df` values on a boolean mask of `o`.
# NOTE(review): `df` is first assigned in a cell further down this notebook, so
# running the cells strictly top to bottom on a fresh kernel fails here.
background = o[50:, :].mean(axis=2).astype(bool)
peste = resize_frame(df.values.T, 160, 160, "L")
peste = peste / peste.max()
hv.Image(background) * hv.Image(peste).opts(alpha=0.7)

In [None]:
# ndarray.tostring() is a deprecated alias that NumPy 2.0 removed;
# tobytes() returns the same raw bytes.
background.astype(int).tobytes()

In [None]:
# `o` as an Image next to the transposed `df` values as a Raster.
(
    hv.Image(o).opts(norm={'axiswise': True})
    + hv.Raster(df.values.T[::1, ::1]).opts(alpha=0.5, normalize=True)
).opts(norm={'axiswise': True})

In [None]:
# Bind the critic's `buffer_df` attribute to `df`.
df = swarm.critic.buffer_df

In [None]:
df.hvplot(kind="heatmap")

In [None]:
swarm.critic.plot_grid()

In [None]:
# Rebinds `df`: slice 0 of the last axis of the critic memory, labelled with the
# critic's private `_cols` / `_index` attributes.
df = pd.DataFrame(swarm.critic.memory[:, :, 0],
                                      columns=swarm.critic._cols, index=swarm.critic._index)

In [None]:
# Run the swarm; binding the result to `_` keeps it from being echoed as cell output.
_ = swarm.run_swarm(print_every=10)

In [None]:
# Add 1 at positions (1, 0) and (2, 2) of a 5x5 zero array, then display it.
x = np.zeros((5, 5))
x[[1, 2], [0, 2]] += 1
x

In [None]:
from plangym.montezuma import MyMontezuma

In [None]:
# Rebinds `mm` to a MyMontezuma instance (class imported from plangym.montezuma).
mm = MyMontezuma()

In [None]:
# Rebinds `o` to what the wrapped environment's reset() returns.
o = mm.env.reset()

In [None]:
mm.get_face_pixels(o)

In [None]:
# Indices where channel 0 of `o` equals 228.
np.where(o[:, :, 0] == 228)

In [None]:
# Scale the second coordinate of every face pixel by mm.x_repeat, then average
# each coordinate over all pixels.
face_pixels = [(y, x * mm.x_repeat) for y, x in mm.get_face_pixels(o)]
np.mean(face_pixels,axis=0)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Boolean mask of channel-0 values equal to 228, with the first 50 rows cropped away.
plt.imshow(o[50:, :, 0] == 228)

In [None]:
o[50:, :, 0].shape

In [None]:
# The last three observation columns are bound to `x`, `y` and `rooms`.
x = swarm.walkers.env_states.observs[:, -3]
y = swarm.walkers.env_states.observs[:, -2]
rooms = swarm.walkers.env_states.observs[:, -1]
rooms

In [None]:
# Integer ids 0 .. 320*160-1 arranged as a (320, 160, 1) array.
grid = np.arange(320*160).reshape((320, 160, 1))

In [None]:
# NOTE(review): the last axis of `grid` has length 1, so any value in `rooms`
# other than 0 (or -1) raises IndexError here -- confirm this is intended.
grid[x.astype(int), y.astype(int), rooms.astype(int)]

In [None]:
# Mean of ten hand-entered values.
np.mean([78, 80, 81, 78, 79, 80, 81, 79, 80, 81])

In [None]:
# Take three randomly sampled actions, rendering and printing mm.pos after each.
for i in range(3):
    mm.step(mm.env.action_space.sample())
    mm.render()
    print(mm.pos)

In [None]:
# Two small example sets: the integers 1..5 and the odd members among them.
old = set(range(1, 6))
new = {member for member in old if member % 2}

In [None]:
# Set difference: the elements of `old` that are not in `new`.
old-new

In [None]:
# NOTE(review): `env` is the ParallelEnvironment built earlier in this notebook;
# confirm that it exposes a `pos` attribute.
env.pos

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Show rows 25-184 of `obs`.
# NOTE(review): `obs` is first assigned by the env.reset() cell at the end of this
# notebook, so this cell fails when run top to bottom on a fresh kernel.
plt.imshow(obs[25:185,:,])

In [None]:
15 *15

In [None]:
from PIL import Image
# NOTE(review): resize_frame is defined twice in this notebook; keep a single definition.
def resize_frame(frame: np.ndarray, height: int, width: int, mode="RGB") -> np.ndarray:
    """
    Use PIL to resize a frame to a specified height and width.

    Args:
        frame: Target numpy array representing the image that will be resized.
        height: Height of the resized image, in pixels.
        width: Width of the resized image, in pixels.
        mode: PIL mode the image is converted to before resizing (for example
            "RGB" or "L").

    Returns:
        The resized frame as a numpy array whose first two dimensions are
        (height, width).
    """
    image = Image.fromarray(frame).convert(mode)
    # PIL's Image.resize takes its size as (width, height); passing the values
    # in that order is what yields an array with `height` rows.
    resized = image.resize((width, height))
    return np.array(resized)

In [None]:
# Rows 25-184 of `obs`, converted to mode "L" and resized with width=45, height=45.
plt.imshow(resize_frame(obs[25:185,:], width=45, height=45, mode="L"))

In [None]:
# IPython help syntax (trailing `?`): shows the documentation of run_swarm.
# NOTE(review): interactive leftover; it is not valid plain Python outside IPython.
swarm.run_swarm?

In [None]:
# Reset once, then step ten copies of the returned state with randomly sampled actions.
state, obs = env.reset()

states = [state.copy() for _ in range(10)]
actions = [env.action_space.sample() for _ in range(10)]

# The value returned by step_batch is unpacked as a 5-tuple below.
data = env.step_batch(states=states, actions=actions)
new_states, observs, rewards, ends, infos = data