In [None]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os

import numpy as np
import torch
import time

import pandas as pd
from carle.env import CARLE
from carle.mcl import RND2D, AE2D, SpeedDetector, PufferDetector, CornerBonus
from game_of_carle.agents.grnn import ConvGRNN
from game_of_carle.agents.carla import CARLA
from game_of_carle.agents.harli import HARLI
from game_of_carle.agents.toggle import Toggle, BilateralToggle
from game_of_carle.algos.cma import CMAPopulation

import bokeh
import bokeh.io as bio
from bokeh.io import output_notebook, show
from bokeh.plotting import figure

from bokeh.layouts import column, row
from bokeh.models import TextInput, Button, Paragraph
from bokeh.models import ColumnDataSource

from bokeh.events import DoubleTap, Tap

# Direct Bokeh output to the notebook (bokeh.io.output_notebook).
output_notebook()

In [None]:
# Initialize the environment, the agent population, and the rule set.

env = CARLE(instances=2, device="cpu", height=128, width=128)

# Prefix handed to the population as `save_path`.
my_path = "../policies/interactive_evolution_face_"
# Rule string in B<birth>/S<survive> form.
my_rules = "B3/S345678"

agent = CMAPopulation(
    BilateralToggle,
    device="cpu",
    save_path=my_path,
    lr=.1,
    population_size=16,
)

env.rules_from_string(my_rules)

In [None]:

# Choose the torch device. The condition is a hand-edited constant, so the
# notebook currently always runs on the CPU; change it to select CUDA.
my_device = torch.device("cpu") if (1) else torch.device("cuda")

def modify_doc(doc):
    """Build the interactive agent-selection app on a Bokeh document.

    Creates two image panels (one per environment instance), the control,
    rule and selection widgets, and the callbacks that connect them, then
    adds everything to `doc`.

    Args:
        doc: Bokeh document that receives the layouts and the periodic
            update callback.
    """

    with torch.no_grad():

        global obs, my_period

        def wide_button(label):
            # All buttons share the same stretch-to-row sizing.
            return Button(sizing_mode="stretch_width", label=label)

        obs = env.reset()

        panel_size = 3*256
        p0 = figure(plot_width=panel_size, plot_height=panel_size)
        p1 = figure(plot_width=panel_size, plot_height=panel_size)

        # Period handed to `doc.add_periodic_callback` when the run starts.
        my_period = 512

        # One single-image data source per panel, seeded with the first observation.
        source_0 = ColumnDataSource(data={"my_image_0": [obs[0].squeeze().cpu().numpy()]})
        source_1 = ColumnDataSource(data={"my_image_1": [obs[1].squeeze().cpu().numpy()]})

        # Both panels draw their image over the same 256 x 256 data range.
        image_options = dict(x=0, y=0, dw=256, dh=256, palette="Magma256")
        p0.image(image='my_image_0', source=source_0, **image_options)
        p1.image(image='my_image_1', source=source_1, **image_options)

        # Playback controls.
        button_go = wide_button("Run >")
        button_slower = wide_button("<< Slower")
        button_faster = wide_button("Faster >>")

        # Rule editing: text boxes pre-filled with the current rules.
        input_birth = TextInput(value=f"{env.birth}")
        input_survive = TextInput(value=f"{env.survive}")

        button_birth = wide_button("Update Birth Rules")
        button_survive = wide_button("Update Survive Rules")

        # Selection (reward) controls.
        button_pick_0 = wide_button("Select agent a")
        button_pick_1 = wide_button("Select agent b")
        button_megapick_0 = wide_button("Mega-select agent a (+10)")
        button_megapick_1 = wide_button("Mega-select agent b (+10)")
        button_pick_none = wide_button("Select neither")

        button_start_over = wide_button("Start over (reset population)")

        # Status line shown under the controls.
        message = Paragraph()
        
        def update():
            """Advance both environments one step and redraw both panels.

            Steps the environment with the actions chosen on the previous
            call, queries the agents for their next actions on the new
            observations, and streams an action-over-observation image to
            each panel.
            """
            global obs, action, action_0, action_1, my_step
            global stretch_pixel

            def overlay(agent_action, frame):
                # Image value is 2 * padded action + observation, capped at 3.
                # The stretch pixel contributes 3 at [0, 0] (presumably to keep
                # the colour range fixed between frames).
                padded = stretch_pixel/2 + env.action_padding(agent_action).squeeze()
                image = (padded*2 + frame.squeeze()).cpu().numpy()
                image[image > 3.0] = 3.0
                return image

            # Apply the actions selected for the previous observation.
            action = torch.cat([action_0[0:1], action_1[0:1]], dim=0)
            obs, _reward, _done, _info = env.step(action)

            # Choose the actions for the next step from the new observations.
            action_0 = agent(obs[0:1], agent_index=0)
            action_1 = agent(obs[1:2], agent_index=1)

            frame_0 = overlay(action_0, obs[0:1])
            frame_1 = overlay(action_1, obs[1:2])

            source_0.stream({"my_image_0": [frame_0]}, rollover=1)
            source_1.stream({"my_image_1": [frame_1]}, rollover=1)

            my_step += 1
            message.text = f"step {my_step}"

        def go():
            """Toggle the simulation between running and paused.

            When the button reads "Run >", registers `update` as a periodic
            callback with the current `my_period` and relabels the button
            "Pause". Otherwise removes that callback and restores "Run >".
            """
            if button_go.label == "Run >":
                # Keep the handle on the function object so that pausing removes
                # exactly the callback started here, rather than whatever
                # happens to be first in `doc.session_callbacks`.
                go.callback = doc.add_periodic_callback(update, my_period)
                button_go.label = "Pause"
            else:
                callback = getattr(go, "callback", None)
                if callback is not None:
                    doc.remove_periodic_callback(callback)
                    go.callback = None
                button_go.label = "Run >"

        def faster():
            """Halve the update period (never below 1) and re-arm the callback."""
            global my_period
            my_period = max(my_period / 2, 1)
            # Two toggles: a running simulation is stopped and restarted with
            # the new period; a paused one is started and paused again.
            for _ in range(2):
                go()

        def slower():
            """Double the update period (never above 8192) and re-arm the callback."""
            global my_period
            my_period = min(my_period * 2, 8192)
            # Two toggles: a running simulation is stopped and restarted with
            # the new period; a paused one is started and paused again.
            for _ in range(2):
                go()
                             
        def reset():
            """Reset both environments, pick first actions, and redraw.

            Also rebuilds `stretch_pixel` (all zeros except a 3 at [0, 0]) and
            sets the step counter back to zero.
            """
            global obs, stretch_pixel, action_0, action_1, my_step

            obs = env.reset()
            first, second = obs[0:1], obs[1:2]

            action_0 = agent(first, agent_index=0)
            action_1 = agent(second, agent_index=1)

            stretch_pixel = torch.zeros_like(first).squeeze()
            stretch_pixel[0,0] = 3

            panels = (
                (source_0, "my_image_0", first),
                (source_1, "my_image_1", second),
            )
            for source, key, frame in panels:
                image = (stretch_pixel + frame.squeeze()).cpu().numpy()
                source.stream({key: [image]}, rollover=8)

            my_step = 0
                             

        def pick_agent_0():
            """Report rewards [1, 0] (agent a preferred) and start a new round.

            The environment reset, first actions, and redraw are delegated to
            `reset()` instead of repeating its body, the same way
            `start_over()` does.
            """
            agent.step(rewards=[1.,0.])
            reset()
            
        def megapick_agent_0():
            """Report rewards [10, 0] (agent a strongly preferred) and start a new round.

            The environment reset, first actions, and redraw are delegated to
            `reset()` instead of repeating its body, the same way
            `start_over()` does.
            """
            agent.step(rewards=[10.,0.])
            reset()

        def pick_agent_1():
            """Report rewards [0, 1] (agent b preferred) and start a new round.

            The environment reset, first actions, and redraw are delegated to
            `reset()` instead of repeating its body, the same way
            `start_over()` does.
            """
            agent.step(rewards=[0.,1.])
            reset()
            
        def megapick_agent_1():
            """Report rewards [0, 10] (agent b strongly preferred) and start a new round.

            The environment reset, first actions, and redraw are delegated to
            `reset()` instead of repeating its body, the same way
            `start_over()` does.
            """
            agent.step(rewards=[0.,10.])
            reset()
            
        def pick_none():
            """Report rewards [-1, -1] (neither agent preferred) and start a new round.

            The environment reset, first actions, and redraw are delegated to
            `reset()` instead of repeating its body, the same way
            `start_over()` does.
            """
            agent.step(rewards=[-1.,-1.])
            reset()

        def start_over():
            """Restart the population via `agent.start_over()`, then reset the display."""
            
            agent.start_over()
            reset()
            
        def set_birth_rules():
            """Apply the birth rule from the text box and show the active rules."""
            env.birth_rule_from_string(input_birth.value)

            birth_part = "".join(str(elem) for elem in env.birth)
            survive_part = "".join(str(elem) for elem in env.survive)

            message.text = "Rules updated to B" + birth_part + "/S" + survive_part

        def set_survive_rules():
            """Apply the survive rule from the text box and show the active rules."""
            env.survive_rule_from_string(input_survive.value)

            birth_part = "".join(str(elem) for elem in env.birth)
            survive_part = "".join(str(elem) for elem in env.survive)

            message.text = "Rules updated to B" + birth_part + "/S" + survive_part
            
        def human_toggle(event):
            """Toggle one action cell at the tapped location and redraw.

            The tapped cell is set, in both instances, to the logical negation
            of instance 0's current value at that cell. The modified `action`
            is then split back into `action_0` / `action_1` so the next
            `update()` applies it.

            Args:
                event: Bokeh tap event; `event.x` / `event.y` are read as plot
                    coordinates of the tap.
            """
            global action
            global action_0
            global action_1

            # Plot coordinates -> grid cell (the images are drawn 256 units
            # wide and high); y selects the row, x selects the column.
            cell_row = np.round(env.height*event.y/256-0.5)
            cell_col = np.round(env.width*event.x/256-0.5)

            # Shift from grid coordinates into action-window coordinates
            # (assumes the action window is centred in the grid).
            cell_row -= (env.height - env.action_height) / 2
            cell_col -= (env.width - env.action_width) / 2

            # Clamp each index against its own axis: rows by action_height,
            # columns by action_width. Plain ints avoid uint8 wrap-around on
            # action windows larger than 256 cells.
            cell_row = int(np.clip(cell_row, 0, env.action_height-1))
            cell_col = int(np.clip(cell_col, 0, env.action_width-1))

            action[:, :, cell_row, cell_col] = 1.0 * (not(action[0, :, cell_row, cell_col]))

            padded_action_0 = stretch_pixel/2 + env.action_padding(action[0]).squeeze()
            padded_action_1 = stretch_pixel/2 + env.action_padding(action[1]).squeeze()

            my_img_0 = (padded_action_0*2 + obs[0].squeeze()).cpu().numpy()
            my_img_1 = (padded_action_1*2 + obs[1].squeeze()).cpu().numpy()

            # Cap the overlay at the maximum displayed value.
            my_img_0[my_img_0 > 3.0] = 3.0
            my_img_1[my_img_1 > 3.0] = 3.0

            new_data_0 = dict(my_image_0=[my_img_0])
            new_data_1 = dict(my_image_1=[my_img_1])

            source_0.stream(new_data_0, rollover=8)
            source_1.stream(new_data_1, rollover=8)

            action_0 = action[0:1]
            action_1 = action[1:2]


        global action_0
        global action_1
        global action

        # Start from all-zero actions for both instances.
        action_shape = (1, 1, env.action_height, env.action_width)
        action_0 = torch.zeros(*action_shape).to(my_device)
        action_1 = torch.zeros(*action_shape).to(my_device)
        action = torch.cat([action_0[0:1], action_1[0:1]], dim=0)

        # Rule and playback buttons.
        for button, handler in (
            (button_birth, set_birth_rules),
            (button_survive, set_survive_rules),
            (button_go, go),
            (button_faster, faster),
            (button_slower, slower),
        ):
            button.on_click(handler)

        # A tap on either panel toggles an action cell.
        for plot in (p0, p1):
            plot.on_event(Tap, human_toggle)

        # Selection buttons.
        for button, handler in (
            (button_pick_0, pick_agent_0),
            (button_pick_1, pick_agent_1),
            (button_megapick_0, megapick_agent_0),
            (button_megapick_1, megapick_agent_1),
            (button_pick_none, pick_none),
            (button_start_over, start_over),
        ):
            button.on_click(handler)

        control_layout = row(button_slower, button_go, button_faster)
        rule_layout = row(input_birth, button_birth, input_survive, button_survive)

        pick_none_row = row(button_pick_none)
        start_over_row = row(button_start_over)
        pick_layout = row(button_pick_0, button_megapick_0, button_pick_1, button_megapick_1)

        display_layout_0 = row(column(p0), column(p1))
        message_layout = row(message)

        # Roots are added in the order they should appear on the page.
        for root in (
            display_layout_0,
            pick_layout,
            pick_none_row,
            start_over_row,
            control_layout,
            rule_layout,
            message_layout,
        ):
            doc.add_root(root)

        # Draw the initial state and choose the first actions.
        reset()

# Launch the app: `modify_doc` is handed to Bokeh's `show` as the application function.
show(modify_doc)    