In [None]:
import numpy as np
from plangym import AtariEnvironment, ParallelEnvironment

from fragile.core.dt_sampler import GaussianDt
from fragile.core.env import DiscreteEnv
from fragile.core.models import RandomDiscrete
from fragile.core.states import States
from fragile.core.swarm import Swarm
from fragile.core.walkers import Walkers
from fragile.atari.env import AtariEnv
from fragile.atari.walkers import AtariWalkers
from fragile.experimental.walkers import MetricWalkers
from fragile.core.tree import HistoryTree

# Environment: parallel pool of MsPacman emulators (RAM observations).
env = ParallelEnvironment(
    env_class=AtariEnvironment,
    name="MsPacman-ram-v0",
    min_dt=2,
    episodic_live=True,
    autoreset=True,
    clone_seeds=True,
    blocking=False,
)

# Step-length sampler handed to the model below (bounds / loc / scale as named).
dt = GaussianDt(min_dt=2, max_dt=1000, loc_dt=3, scale_dt=2)

swarm = Swarm(
    # Factories the Swarm calls to build its components.
    env=lambda: AtariEnv(env),
    model=lambda x: RandomDiscrete(x, dt_sampler=dt),
    walkers=lambda **kwargs: MetricWalkers.from_walkers_class(Walkers, **kwargs),
    # Keep the whole search tree (no pruning); the plotting cells read swarm.tree.data.
    tree=HistoryTree,
    use_tree=True,
    prune_tree=False,
    # Run configuration.
    n_walkers=40,
    max_iters=100,
    reward_scale=2,
    minimize=False,
    plot_interval=50,
)

swarm.walkers.accumulate_rewards = True


In [None]:
# Display the walkers' best-value evolution plot. It is created before `run_swarm`,
# so presumably it updates live during the run — TODO confirm.
swarm.walkers.plot_best_evolution()

In [None]:

# Run the search; binding the result to `_` keeps the cell from echoing it.
_ = swarm.run_swarm(print_every=100)

In [None]:
import networkx as nx
# Hierarchical layout of the search tree from Graphviz's "dot" program (via pydot).
pos = nx.nx_pydot.graphviz_layout(swarm.tree.data, prog="dot")

In [None]:
from plot_swarm import get_game_observs  # written to disk by the `%%file plot_swarm.py` cell

# RGB screens along the best path, keyed by iteration number.
obs = get_game_observs(swarm)

In [None]:
from plot_swarm import get_plot_graph  # written to disk by the `%%file plot_swarm.py` cell

# Plot-ready copy of the search tree with the best path highlighted.
plot_g = get_plot_graph(swarm)

In [None]:
# NOTE(review): `a` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. Looks like leftover scratch — confirm and delete.
a.pop("cac")

In [None]:
import pickle 
# NOTE(review): `embs` is only assigned by a later cell (`embs = create_embedding_layout(swarm)`).
# On a fresh kernel this cell raises NameError unless that cell has run first.
with open("embeddings_tree_demo.pickle", mode="wb") as handle:
    pickle.dump(embs, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
from plot_swarm import create_embedding_layout  # written to disk by the `%%file plot_swarm.py` cell

# 2-D UMAP position for every non-root tree node (node id -> coordinates).
embs = create_embedding_layout(swarm)

In [None]:
import pickle 
# Persist the node -> embedding mapping computed above.
with open("embeddings.pickle", mode="wb") as handle:
    pickle.dump(embs, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
%%HTML
<style>
/* Use the full browser width for the notebook. */
.container { width: 100% !important; }
/* `align` is not a CSS property and was ignored; auto side margins centre the block. */
.input {
    width: 60% !important;
    margin-left: auto;
    margin-right: auto;
}
.text_cell {
    width: 70% !important;
    font-size: 16px;
}
.title { text-align: center !important; }
</style>



In [None]:
from plot_swarm import get_best_path  # never imported elsewhere in the notebook

# Best walker's path through the search tree: per-node states / iterations, per-edge actions.
states, actions, n_iters, nodes, edges = get_best_path(swarm)

In [None]:
# Number of path states vs. number of path actions.
tuple(map(len, (states, actions)))

In [None]:
from plot_swarm import get_game_observs  # written to disk by the `%%file plot_swarm.py` cell

# RGB screens along the best path, keyed by iteration number.
obs = get_game_observs(swarm)

In [None]:
import pickle 
# Persist the best-path screens (iteration -> RGB frame).
with open("observs_tree_demo.pickle", mode="wb") as handle:
    pickle.dump(obs, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pickle 
# Persist the node -> position mapping currently bound to `pos`.
with open("pos_tree_demo.pickle", mode="wb") as handle:
    pickle.dump(pos, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
from plot_swarm import get_plot_graph  # written to disk by the `%%file plot_swarm.py` cell
import pickle

# Build the plot-ready tree and persist it.
plot_g = get_plot_graph(swarm)

with open('graph_tree_demo.pickle', 'wb') as handle:
    pickle.dump(plot_g, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
from plot_swarm import plot_iteration, get_plot_graph, get_game_observs, create_embedding_layout

In [None]:
import pickle 
# NOTE(review): these filenames differ from the *_tree_demo.pickle files written by this
# notebook; confirm which artifacts are meant to be loaded. pickle.load can execute
# arbitrary code, so only load files you produced yourself.
def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


pos = _load_pickle("pos.pickle")
obs = _load_pickle("observs.pickle")
plot_g = _load_pickle("graph.pickle")

In [None]:
import holoviews as hv
# Activate the HoloViews Bokeh plotting backend for the cells below.
hv.extension("bokeh")

In [None]:
def _plot_subgraph(plot_g, embs, bundle: bool=False):
    """Render ``plot_g`` at the ``embs`` positions, coloured by sqrt(``cum_reward``).

    Node size comes from the ``last_size`` attribute and node outline alpha from
    ``last_line_alpha`` (both set by ``create_subgraph``); edges are styled by ``final``.

    Args:
        plot_g: networkx graph carrying the attributes above.
        embs: mapping node -> 2-D position.
        bundle: if True, bundle the edges with datashader's ``bundle_graph``.

    Returns:
        The styled (optionally bundled) element with framewise normalisation.
    """
    graph = hv.Graph.from_networkx(plot_g, embs)

    graph.opts(
        # `** 0.5` is the square root. `** 1/2` would parse as `(x ** 1) / 2`, a plain
        # rescale that a linear colormap renders like the raw reward.
        # Assumes cum_reward >= 0 (sqrt of a negative value is NaN).
        node_color=hv.dim('cum_reward') ** 0.5, node_cmap="viridis", node_size=hv.dim('last_size'),
        edge_line_width=hv.dim('final')*2,
        node_line_width=1.5,
        node_alpha=0.8,
        xaxis=None, yaxis=None,
        edge_alpha=hv.dim('final'), edge_line_color=hv.dim('final'), edge_cmap=["white", "red"],
        node_line_color="red", node_line_alpha=hv.dim('last_line_alpha'),
        width=800, height=600, bgcolor='gray', colorbar=True,
        title="Sampling MsPacman-ram using a FractalAI Swarm",
    )
    if bundle:
        # bundle_graph is only imported inside plot_swarm.py, not in the notebook namespace.
        from holoviews.operation.datashader import bundle_graph
        bundled = bundle_graph(graph)
        return bundled.opts(norm=dict(framewise=True))
    return graph.opts(norm=dict(framewise=True))

In [None]:
def plot_subgraph_embs(plot_g, embs, bundle: bool=False):
    """Render ``plot_g`` at the ``embs`` positions, coloured by ``cum_reward``.

    Node size comes from ``last_size`` and node outline alpha from ``last_line_alpha``
    (both set by ``create_subgraph``); edge width, alpha and colour follow ``final``.

    Args:
        plot_g: networkx graph carrying the attributes above.
        embs: mapping node -> 2-D position.
        bundle: if True, bundle the edges with datashader's ``bundle_graph``.

    Returns:
        The styled (optionally bundled) element with framewise normalisation.
    """
    graph = hv.Graph.from_networkx(plot_g, embs)

    graph.opts(node_color=hv.dim('cum_reward'), node_cmap="viridis", node_size=hv.dim('last_size'),
               edge_line_width=hv.dim('final'),
               node_line_width=1.5,
               node_alpha=0.8,
               xaxis=None, yaxis=None,
               # Doubling `final` (1 or 0.3) gives 2 for best-path edges; alpha must stay
               # within [0, 1], so cap it at 1 while keeping the 0.6 for other edges.
               edge_alpha=(hv.dim('final')*2).clip(max=1),
               edge_line_color=hv.dim('final'), edge_cmap=["white", "red"],
               node_line_color="red", node_line_alpha=hv.dim('last_line_alpha'),
               width=800, height=600, bgcolor='gray', colorbar=True)
    if bundle:
        # bundle_graph is only imported inside plot_swarm.py, not in the notebook namespace.
        from holoviews.operation.datashader import bundle_graph
        bundled = bundle_graph(graph)
        return bundled.opts(norm=dict(framewise=True))
    return graph.opts(norm=dict(framewise=True))

In [None]:
import numpy as np

# Log-compress per-node rewards for plotting. The untransformed value is kept under
# "raw_reward" so that re-running this cell recomputes log(1 + reward) from the original
# value instead of taking the log of an already-logged reward.
for n, attrs in plot_g.nodes(data=True):
    if "reward" in attrs:
        raw = attrs.setdefault("raw_reward", attrs["reward"])
        attrs["reward"] = np.log1p(raw)

In [None]:
from plot_swarm import create_embedding_layout  # written to disk by the `%%file plot_swarm.py` cell

# NOTE(review): same (expensive, UMAP-based) computation as an earlier cell; compute once.
embs = create_embedding_layout(swarm)

In [None]:
from functools import partial
from matplotlib.cm import viridis

# Bind the data once; the DynamicMap only varies the iteration window.
_plot_func = partial(plot_iteration, graph=plot_g, embeddings=embs, observs=obs, plot_func=plot_subgraph_embs)


def plot_func(iteration=1, memory=1, bundle=False):
    """Plot nodes from the last ``memory`` iterations up to ``iteration``, next to the game screen.

    The window starts at ``max(0, iteration - memory)``; ``bundle`` is forwarded to
    the subgraph renderer. Axes are not shared between the screen and the graph.
    """
    return _plot_func(iteration, start=max(0, iteration - memory), bundle=bundle).opts(shared_axes=False)


# When run live, this cell's output should match the behavior of the GIF below
dmap = hv.DynamicMap(plot_func, kdims=['iteration', "memory"])
# redim.range takes (low, high) bounds for each key dimension.
dmap.redim.range(iteration=(1, 99), memory=(1, 99))

In [None]:
from holoviews.operation.datashader import datashade, dynspread
# `plot_func` returns an image + graph Layout, which has no `.nodes`. Build the graph
# element on its own (no observations) covering iterations 0..100 and datashade its nodes.
all_iters_graph = plot_iteration(100, graph=plot_g, embeddings=embs, start=0, plot_func=plot_subgraph_embs)
dynspread(datashade(all_iters_graph.nodes, cmap=["cyan"]))

In [None]:
# Iterations for which a best-path screen is available (presumably n_iter values — verify).
obs.keys()

In [None]:
import os

# Render one PNG per iteration (memory window of 1) for the GIF assembled below.
# Create the output directory first so saving does not fail on a fresh checkout.
os.makedirs("swarm_points", exist_ok=True)
for i in range(1, 100):
    plot = plot_func(i, 1)
    hv.save(plot, 'swarm_points/walkers{0:04d}.png'.format(i))

In [None]:
from PIL import Image
import os
def assemble_gif(name: str, img_dir="tree_s"):
    """Stitch the numbered frame files in ``img_dir`` into ``<name>.gif``.

    Frames are ordered by the integer at the end of their file stem
    (e.g. ``walkers0012.png`` -> 12). Entries that are not regular files, or whose stem
    does not end in digits (``.ipynb_checkpoints``, notes, ...), are ignored. The GIF
    loops forever with a 100 ms frame duration.

    Raises:
        ValueError: if ``img_dir`` contains no numbered frame files.
    """
    import re

    frames = []
    for entry in os.listdir(img_dir):
        path = os.path.join(img_dir, entry)
        match = re.search(r"(\d+)$", os.path.splitext(entry)[0])
        if match is not None and os.path.isfile(path):
            frames.append((int(match.group(1)), path))
    if not frames:
        raise ValueError("no numbered frames found in %r" % img_dir)
    frames.sort()  # numeric order; ties fall back to the path

    images = [Image.open(path) for _, path in frames]
    try:
        images[0].save(
            "%s.gif" % name,
            save_all=True,
            append_images=images[1:],
            loop=0,
            duration=100,
        )
    finally:
        # Release the file handles held by the lazily loaded images.
        for image in images:
            image.close()

# Stitch the frames rendered into swarm_points/ into swarm_short.gif.
assemble_gif(name="swarm_short", img_dir="swarm_points")

In [None]:
from plot_swarm import create_subgraph  # written to disk by the `%%file plot_swarm.py` cell

# Nodes with 15 <= n_iter <= 20, plus their embedding positions.
g, new_embs = create_subgraph(15, 20, plot_g, embs)

In [None]:
from plot_swarm import get_plot_graph, plot_graph  # written to disk by the `%%file plot_swarm.py` cell

# Rebuild the plot graph from the current swarm and render it at the positions in `embs`.
plot_g = get_plot_graph(swarm)

graph = plot_graph(plot_g, embs)

In [None]:
%%file plot_swarm.py
import numpy as np
import hvplot
import hvplot.pandas
import hvplot.networkx as hvnx
import networkx as nx
import holoviews as hv
from fragile.core.utils import resize_frame
from umap import UMAP
import copy
import warnings
from holoviews.operation.datashader import datashade, bundle_graph
# Suppress every warning in the importing process. This modifies the global warnings
# filter list, so it also hides warnings from unrelated libraries — NOTE(review).
warnings.filterwarnings("ignore")
# Activate the HoloViews Bokeh backend when this module is imported.
hv.extension("bokeh")

def get_path_nodes_and_edges(g, leaf_name):
    """Follow parent links from ``leaf_name`` back towards the root node ``0``.

    Args:
        g: directed graph whose edges point parent -> child (anything exposing
            ``in_edges(nbunch)`` like networkx ``DiGraph``).
        leaf_name: node to start from.

    Returns:
        ``(nodes, edges)`` in leaf-to-root order. ``nodes`` starts with
        ``int(leaf_name)`` and ends at ``0``, or at the first node without a parent
        if ``0`` is not an ancestor. ``edges`` holds the ``(parent, child)`` pairs
        traversed.
    """
    nodes = [int(leaf_name)]
    edges = []
    parent = -100  # sentinel: any value other than the root id 0
    while parent != 0:
        parents = list(g.in_edges([leaf_name]))
        if not parents:
            # Reached a parentless node before the root: return the partial path.
            return nodes, edges
        parent = parents[0][0]
        nodes.append(parent)
        edges.append((parent, leaf_name))
        leaf_name = int(parent)
    return nodes, edges

def get_best_path(swarm):
    """Return the tree path that ends at the walker with the highest ``cum_rewards``.

    The leaf-to-root walk is flipped into root-first order, and its first node and
    first edge are dropped.

    Returns:
        ``(states, actions, n_iters, nodes, edges)``: the ``state`` and ``n_iter``
        attributes of each remaining path node, the ``action`` attribute of each
        remaining path edge, and the node ids / edge tuples themselves.
    """
    walker_states = swarm.walkers.states
    best_walker = walker_states.id_walkers[walker_states.cum_rewards.argmax()]
    leaf = swarm.tree.node_names[best_walker]
    tree = swarm.tree.data

    leaf_first_nodes, leaf_first_edges = get_path_nodes_and_edges(tree, leaf)
    nodes = leaf_first_nodes[::-1][1:]
    edges = leaf_first_edges[::-1][1:]

    node_view = tree.nodes
    states = [node_view[n]["state"] for n in nodes]
    n_iters = [node_view[n]["n_iter"] for n in nodes]
    actions = [tree.edges[e]["action"] for e in edges]
    return states, actions, n_iters, nodes, edges

def add_image_from_node(swarm, node_id):
    """Replay the step that created ``node_id`` and return a small frame of the screen.

    The parent's stored ``state`` is stepped with the ``action`` stored on the
    parent -> ``node_id`` edge. The resulting RGB screen is cropped to rows 2-169,
    reduced to its first colour channel and passed to ``resize_frame(..., 60, 60, "L")``.
    Returns ``None`` if ``node_id`` has no parent.
    """
    tree = swarm.tree.data
    incoming = list(tree.in_edges([node_id]))
    if not incoming:
        return None
    parent = incoming[0][0]
    action = tree.edges[(parent, node_id)]["action"]
    state = tree.nodes[parent]["state"]
    # The step's return value is unused; the screen is read from the emulator afterwards.
    swarm.env._env.step(state=state, action=action)
    screen = swarm.env._env.unwrapped.ale.getScreenRGB()
    return resize_frame(screen[2:170, :, 0], 60, 60, "L")


def create_embedding_layout(swarm):
    """Place tree nodes in 2-D by running UMAP on the frames rendered for them.

    The first node in the tree's node order is skipped (presumably the root, which
    has no parent to replay from — verify).

    Returns:
        dict mapping node id -> 2-component UMAP embedding.
    """
    tree_nodes = list(swarm.tree.data.nodes())[1:]
    frames = np.array([add_image_from_node(swarm, node) for node in tree_nodes])
    flat = frames.reshape(len(frames), -1)  # one flattened frame per row
    reducer = UMAP(n_components=2, min_dist=0.99, n_neighbors=50)
    coords = reducer.fit_transform(flat)
    return dict(zip(tree_nodes, coords))

def get_plot_graph(swarm):
    """Build an undirected, plot-ready copy of the swarm tree that highlights the best path.

    Each node keeps a deep copy of its tree attributes except ``state``, plus
    ``final`` (1 on the best path, 0.3 otherwise), ``node_alpha`` (1 / 0.2) and
    ``line_alpha`` (1 / 0.0). Each edge gets ``weight`` (its ``action`` as float) and
    ``final`` (1 on the best path, 0.3 otherwise).
    """
    tree = swarm.tree.data
    _, _, _, best_nodes, best_edges = get_best_path(swarm)
    # Sets give O(1) membership tests in the loops below.
    best_nodes = set(best_nodes)
    best_edges = set(best_edges)

    plot_g = nx.Graph()
    for n in tree.nodes():
        is_best = n in best_nodes
        # Skip the emulator `state` instead of deep-copying it only to discard it.
        node_attrs = {k: copy.deepcopy(v) for k, v in tree.nodes[n].items() if k != "state"}
        plot_g.add_node(n,
                        final=1 if is_best else 0.3,
                        node_alpha=1 if is_best else 0.2,
                        line_alpha=1 if is_best else 0.0,
                        **node_attrs,
                        )
    for a, b in tree.edges():
        plot_g.add_edge(a, b, weight=float(tree.edges[(a, b)]["action"]),
                        final=1 if (a, b) in best_edges else 0.3)
    return plot_g


def plot_graph(plot_g, embs):
    """Render the whole plot graph at the ``embs`` positions, coloured by ``n_iter``.

    Node alpha and outline alpha come from the ``node_alpha`` / ``line_alpha``
    attributes, and edge width / alpha / colour from ``final`` (all set by
    ``get_plot_graph``). Returns the styled ``hv.Graph``.
    """
    style = dict(
        node_color=hv.dim('n_iter'),
        node_cmap="viridis",
        node_size=3,
        node_line_width=0.5,
        node_alpha=hv.dim('node_alpha'),
        node_line_color="red",
        node_line_alpha=hv.dim('line_alpha'),
        edge_line_width=hv.dim('final') * 0.2,
        edge_alpha=hv.dim('final'),
        edge_line_color=hv.dim('final'),
        edge_cmap=["white", "red"],
        width=800,
        height=600,
        bgcolor='gray',
        colorbar=True,
    )
    graph = hv.Graph.from_networkx(plot_g, embs)
    graph.opts(**style)
    return graph

def create_subgraph(start, end, graph, embs=None, key="n_iter"):
    """Return the part of ``graph`` whose nodes satisfy ``start <= node[key] <= end``.

    Kept nodes are copied with all attributes plus two plotting hints:
    ``last_line_alpha`` (1 for nodes at ``end``, else 0) and ``last_size``
    (8 for nodes at ``end``, else 4). An edge is kept only when both endpoints are
    kept. Nodes without ``key`` are treated as outside the window instead of raising
    KeyError.

    Returns:
        ``(subgraph, new_embs)``, where ``new_embs`` is ``embs`` restricted to the kept
        nodes (empty when ``embs`` is None).
    """
    embs = embs if embs is not None else {}
    g = nx.Graph()
    for n, attrs in graph.nodes(data=True):
        n_iter = attrs.get(key)
        if n_iter is None or not start <= n_iter <= end:
            continue
        is_last = n_iter == end
        g.add_node(n, **attrs)
        g.nodes[n]["last_line_alpha"] = 1 if is_last else 0
        g.nodes[n]["last_size"] = 8 if is_last else 4
    for a, b in graph.edges:
        # Both endpoints are in `g` exactly when both fall inside the window.
        if a in g and b in g:
            g.add_edge(a, b, **graph.edges[(a, b)])
    new_embs = {k: v for k, v in embs.items() if k in g}
    return g, new_embs

def plot_subgraph(plot_g, embs, bundle: bool=False):
    """Render a windowed subgraph at the ``embs`` positions, coloured by ``cum_reward``.

    Node size and outline alpha come from ``last_size`` / ``last_line_alpha`` (set by
    ``create_subgraph``); edge width, alpha and colour follow ``final``. With
    ``bundle=True`` the edges are bundled via ``bundle_graph``. Returns the element
    with framewise normalisation.
    """
    style = dict(
        node_color=hv.dim('cum_reward'),
        node_cmap="viridis",
        node_size=hv.dim('last_size'),
        node_line_width=1.5,
        node_alpha=0.8,
        node_line_color="red",
        node_line_alpha=hv.dim('last_line_alpha'),
        edge_line_width=hv.dim('final'),
        edge_alpha=hv.dim('final'),
        edge_line_color=hv.dim('final'),
        edge_cmap=["white", "red"],
        xaxis=None,
        yaxis=None,
        width=800,
        height=600,
        bgcolor='gray',
        colorbar=True,
    )
    graph = hv.Graph.from_networkx(plot_g, embs)
    graph.opts(**style)
    element = bundle_graph(graph) if bundle else graph
    return element.opts(norm=dict(framewise=True))

def plot_iteration(iteration, graph, embeddings, start=0, key="n_iter",
                   bundle=False, observs=None, plot_func=plot_subgraph):
    """Plot the nodes whose ``key`` lies in ``[start, iteration]``, optionally beside a game screen.

    Args:
        iteration: last iteration of the window (inclusive).
        graph: full plot graph to window.
        embeddings: mapping node -> 2-D position.
        start: first iteration of the window (inclusive).
        key: node attribute holding the iteration number.
        bundle: forwarded to ``plot_func``.
        observs: optional mapping iteration -> RGB screen. When non-empty, the screen
            with the latest key ``<= iteration`` is shown. If every key is later than
            ``iteration``, the earliest screen is used so the output stays a Layout.
        plot_func: callable ``(graph, embeddings, bundle=...)`` that renders the subgraph.

    Returns:
        The rendered graph, or ``image + graph`` when a screen is shown.
    """
    g, new_embs = create_subgraph(start=start, end=iteration,
                                  graph=graph, embs=embeddings, key=key)
    graph = plot_func(g, new_embs, bundle=bundle)
    if not observs:
        return graph
    # Choose the key directly. Counting down one iteration at a time never terminates
    # when no key <= iteration exists.
    earlier = [it for it in observs if it <= iteration]
    screen = observs[max(earlier) if earlier else min(observs)]
    image = hv.RGB(screen).opts(xaxis=None, yaxis=None,
                                normalize=True, shared_axes=False)
    return image + graph

def get_game_observs(swarm):
    """Replay the best path and collect the full RGB screen after each step.

    For every path node that has a parent, the parent's stored ``state`` is stepped
    with the action on the connecting edge and the emulator screen is captured.

    Returns:
        dict mapping each such node's ``n_iter`` to its RGB screen.
    """
    _, _, path_iters, path_nodes, _ = get_best_path(swarm)
    tree = swarm.tree.data
    screens = {}
    for node_id, it in zip(path_nodes, path_iters):
        incoming = list(tree.in_edges([node_id]))
        if not incoming:
            continue
        parent = incoming[0][0]
        action = tree.edges[(parent, node_id)]["action"]
        state = tree.nodes[parent]["state"]
        swarm.env._env.step(state=state, action=action)
        screens[it] = swarm.env._env.unwrapped.ale.getScreenRGB()
    return screens

In [None]:
# Display the hv.Graph built by `plot_graph(plot_g, embs)` earlier in the notebook.
graph

In [None]:
 pos  = nx.nx_pydot.graphviz_layout(swarm.tree.data, prog='dot')

In [None]:
import matplotlib.pyplot as plt
# Matplotlib view of the tree at the embedding positions. `embs` skips the first
# (root) node, and nx.draw needs a position for every node it draws, so draw only
# the embedded nodes.
fig, ax = plt.subplots(figsize=(15, 18))
nx.draw(swarm.tree.data.subgraph(embs), pos=embs, ax=ax, node_size=1, alpha=0.2)

In [None]:
# Debug: inspect the attributes stored on tree edge 0 -> 2 (KeyError if that edge is absent).
swarm.tree.data.edges[(0, 2)]

In [None]:
# Display the node -> position mapping currently bound to `pos`.
pos

In [None]:
import hvplot.networkx as hvnx  # hvnx is otherwise only imported inside plot_swarm.py

# Toy graph for trying out hvplot's networkx styling options.
G = nx.Graph()
G.add_weighted_edges_from([
    ('a', 'b', 0.6),
    ('a', 'c', 0.2),
    ('c', 'd', 0.1),
    ('c', 'e', 0.7),
    ('c', 'f', 0.9),
    ('a', 'd', 0.3),
])  # values are stored under the "weight" edge attribute
for node, size in {'a': 20, 'b': 10, 'c': 12, 'd': 5, 'e': 8, 'f': 3}.items():
    G.add_node(node, size=size)

# Separate name so the search-tree layout held in `pos` is not overwritten.
demo_pos = nx.spring_layout(G)  # positions for all nodes

hvnx.draw(G, demo_pos, edge_color='weight', edge_cmap='viridis',
          edge_width=hv.dim('weight')*10, node_size=hv.dim('size')*20)

In [None]:
import hvplot.networkx as hvnx  # hvnx is otherwise only imported inside plot_swarm.py

# NOTE(review): the "size" node attribute only exists on the plot_g built by the
# `plot_g = nx.Graph()` cell further down (n_iter as float); run that cell first.
hvnx.draw(plot_g, pos=embs, node_color=hv.dim('size'), node_cmap='viridis',
          edge_width=hv.dim('weight') * 0.01, alpha=0.7, node_line_color=None, node_size=15,
          width=800, height=600)

In [None]:
# Sanity check: True only if plot_g is connected and acyclic (networkx's tree definition).
nx.is_tree(plot_g)

In [None]:
# Simplified plot graph: node "size" = n_iter as float (-100 when missing),
# edge "weight" = action as float.
plot_g = nx.Graph()
plot_g.add_nodes_from(
    (node, {"size": float(attrs.get("n_iter", -100))})
    for node, attrs in swarm.tree.data.nodes(data=True)
)
plot_g.add_edges_from(
    (a, b, {"weight": float(attrs["action"])})
    for a, b, attrs in swarm.tree.data.edges(data=True)
)

In [None]:
# Debug: inspect the attribute dict stored on tree node 6.
swarm.tree.data.nodes[6]

In [None]:
from holoviews.streams import Pipe, Buffer
from streamz.dataframe import DataFrame
from streamz import Stream
import holoviews as hv
import hvplot.pandas
import hvplot.streamz
import pandas as pd
# Re-activate the Bokeh backend. NOTE(review): the Pipe/Buffer/streamz imports above are
# not used by any visible cell.
hv.extension("bokeh")

In [None]:
# 5x5 toy frame holding 0..24 row by row.
df = pd.DataFrame(data=np.arange(25).reshape(5, 5))

In [None]:
# Heatmap of the toy frame via the hvplot pandas accessor (registered by `import hvplot.pandas`).
df.hvplot.heatmap()

In [None]:
import pandas as pd
# NOTE(review): metrics.csv is not written by any visible cell; confirm its origin.
# Displays the whole frame — consider `.head()` if the file is large.
pd.read_csv("metrics.csv")

In [None]:
# Smoke test of the raw ParallelEnvironment batch API. The `env_`-prefixed names keep
# this cell from overwriting `obs`, `states` and `actions`, which hold the best-path
# data used by the plotting cells.
env_state, env_obs = env.reset()

env_states = [env_state.copy() for _ in range(10)]
env_actions = [env.action_space.sample() for _ in range(10)]

env_data = env.step_batch(states=env_states, actions=env_actions)
new_states, observs, rewards, ends, infos = env_data