In [None]:
!pip install -r battlesnake_gym/requirements.txt
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension

In [None]:
from io import BytesIO
import time

import boto3
import PIL.Image
import sagemaker
import gym
from gym import wrappers
import numpy as np
import mxnet as mx
import matplotlib.pyplot as plt
%matplotlib inline
from importlib import reload
from IPython import display
import ipywidgets as widgets
from IPython.display import display as i_display

from battlesnake_gym.battlesnake_gym.snake_gym import BattlesnakeGym
from heuristics_utils import simulate

# Define the OpenAI Gym environment
Optionally, you can define the initial game state (the situation simulator) of the snakes and food.
To use it, set `USE_INITIAL_STATE = True` and enter the desired coordinates of the snakes and food in the `initial_state` dictionary. The dictionary follows the same format as the Battlesnake API.

In [None]:
# Set to True to start the gym from `initial_state` below (the situation
# simulator); leave False to pass no initial state to the gym.
USE_INITIAL_STATE = False

# Sample initial state for the situation simulator. The dictionary uses the
# same format as the Battlesnake API game state.
initial_state = {
    "turn": 4,
    "board": {
        "height": 11,
        "width": 11,
        "food": [
            {"x": 1, "y": 3},
        ],
        "snakes": [
            {"health": 90, "body": [{"x": 8, "y": 5}]},
            {"health": 90, "body": [{"x": 1, "y": 6}]},
            {"health": 90, "body": [{"x": 3, "y": 3}]},
            {"health": 90, "body": [{"x": 6, "y": 4}]},
        ],
    },
}

# Only hand the sample state to the gym when explicitly enabled.
if not USE_INITIAL_STATE:
    initial_state = None

The parameters below must match the ones used during training (except `initial_state`).

In [None]:
# Board size and snake count; these must match the values used in training.
map_size = (11, 11)
number_of_snakes = 4

env = BattlesnakeGym(
    map_size=map_size,
    number_of_snakes=number_of_snakes,
    observation_type="bordered-51s",
    initial_game_state=initial_state,
)

# Load the trained model
The pretrained models are loaded into an MXNet network. _You can safely ignore the **WARNING** about the type for `data0`._

In [None]:
# Exported symbol/params pair for this board size.
symbol_name = "pretrained_models/Model-{0}x{1}/local-symbol.json".format(*map_size)
params_name = "pretrained_models/Model-{0}x{1}/local-0000.params".format(*map_size)

# Run on a GPU when MXNet can see one, otherwise on the CPU.
if mx.context.num_gpus() > 0:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

net = mx.gluon.SymbolBlock.imports(
    symbol_name,
    ['data0', 'data1', 'data2', 'data3'],
    params_name,
    ctx=ctx,
)
net.hybridize(static_alloc=True, static_shape=True)

# Simulation loop

Run a simulation of the environment using the heuristics that you wrote.
To edit the heuristics, edit the file `LocalEnv/battlesnake_inference/battlesnake_heuristics.py`.
Note that you can track the progress of your work with git.

If you want to change the snake's behavior, update `MyBattlesnakeHeuristics` (source in `LocalEnv/battlesnake_inference/battlesnake_heuristics.py`) before running the next cell.

In [None]:
# Reload the heuristics module so edits to battlesnake_heuristics.py take
# effect without restarting the kernel; `reload` returns the fresh module.
import battlesnake_inference.battlesnake_heuristics
MyBattlesnakeHeuristics = reload(
    battlesnake_inference.battlesnake_heuristics
).MyBattlesnakeHeuristics

heuristics = MyBattlesnakeHeuristics()
(infos, rgb_arrays, actions,
 heuristics_remarks, json_array) = simulate(env, net, heuristics, number_of_snakes)

# Playback the simulation

The next cell defines the user interface for replaying the simulation.

In [None]:
def get_env_json():
    """Return the gym JSON for the frame currently selected on the slider.

    Returns an empty string when the slider position is past the end of
    `json_array`.
    """
    position = slider.value
    return json_array[position] if position < len(json_array) else ""
    
def play_simulation(_, delay=0.2):
    """Animate the replay from the current slider position to the last frame.

    Callback for the "Play" button.

    Args:
        _: The widget that triggered the callback (unused).
        delay: Seconds to pause between frames.
    """
    # Visit every remaining frame index up to len(rgb_arrays) - 1 (the
    # slider's maximum). The end point must not depend on the start position;
    # previously it subtracted slider.value a second time and stopped early.
    for index in range(slider.value + 1, len(rgb_arrays)):
        slider.value = index
        display_image(index)
        time.sleep(delay)

def on_left_button_pressed(_):
    """Callback for the "◄" button: step back one frame, then redraw.

    The frame index never goes below 0; on the first frame the button simply
    redraws it.
    """
    at_first_frame = slider.value <= 0
    if not at_first_frame:
        slider.value -= 1
    display_image(slider.value)

def on_right_button_pressed(_):
    """Callback for the "►" button: step forward one frame, then redraw.

    The frame index stops at the last frame, ``len(rgb_arrays) - 1``, which
    is also the slider's maximum. The previous bound allowed stepping one
    past it.
    """
    if slider.value < len(rgb_arrays) - 1:
        slider.value += 1
    display_image(slider.value)
        
def display_image(index):  
    if index >= len(rgb_arrays):
        return
    info = infos[index]
    action = actions[index]
    heuristics = heuristics_remarks[index]
    snake_colours = env.snakes.get_snake_colours()
        
    line_0 = [widgets.Label("Turn count".format(info["current_turn"])),
                 widgets.Label("Snake")]
    
    line_1 = [widgets.Label(""), widgets.Label("Health")]
    
    line_2 = [widgets.Label("{}".format(info["current_turn"])), 
              widgets.Label("Action")]
    
    line_3 = [widgets.Label(""), widgets.Label("Gym remarks")]
    
    line_4 = [widgets.Label(""), widgets.Label("Heur. remarks")]

    action_convertion_dict = {0: "Up", 1: "Down", 2: "Left", 3: "Right", 4: "None"}
    for snake_id in range(len(action)):
        snake_health = "{}".format(info["snake_health"][snake_id])
        snake_health_widget = widgets.Label(snake_health)
        snake_action = "{}".format(action_convertion_dict[action[snake_id]])
        snake_action_widget = widgets.Label(snake_action)

        snake_colour = snake_colours[snake_id]
        hex_colour = '#%02x%02x%02x' % (snake_colour[0], snake_colour[1], snake_colour[2])
        snake_colour_widget = widgets.HTML(value = f"<b><font color="+hex_colour+">⬤</b>")

        gym_remarks = ""
        if snake_id in info["snake_info"]:
            if info["snake_info"][snake_id] != "Did not colide": 
                gym_remarks = "{}".format(info["snake_info"][snake_id])
        gym_remarks_widget = widgets.Label(gym_remarks)
        
        heuris_remarks = "{}".format(heuristics[snake_id])
        heuris_remarks_widget = widgets.Label(heuris_remarks)

        line_0.append(snake_colour_widget)
        line_1.append(snake_health_widget)
        line_2.append(snake_action_widget)
        line_3.append(gym_remarks_widget)
        line_4.append(heuris_remarks_widget)

    line_0_widget = widgets.VBox(line_0)
    line_1_widget= widgets.VBox(line_1)
    line_2_widget = widgets.VBox(line_2)
    line_3_widget = widgets.VBox(line_3)
    line_4_widget = widgets.VBox(line_4)
   
    info_widget = widgets.HBox([line_0_widget, line_1_widget, line_2_widget, line_3_widget, line_4_widget])
        
    image = PIL.Image.fromarray(rgb_arrays[index])
    f = BytesIO()
    image.save(f, "png")
    
    states_widget = widgets.Image(value=f.getvalue(), width=500)
    main_widgets_list = [states_widget, info_widget]
    
    main_widget = widgets.HBox(main_widgets_list)
    
    display.clear_output(wait=True)
    i_display(navigator)
    i_display(main_widget)
    
# Replay controls: step back, step forward, scrub, and auto-play.
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')
slider = widgets.IntSlider(max=len(rgb_arrays) - 1)
play_button = widgets.Button(description='Play')

# Connect each button to its callback.
left_button.on_click(on_left_button_pressed)
right_button.on_click(on_right_button_pressed)
play_button.on_click(play_simulation)

navigator = widgets.HBox([left_button, right_button, slider, play_button])

# Start the replay on the first frame.
display_image(index=0)

To get a JSON representation of the gym (environment) at the turn currently selected on the slider, run the following cell. You can also use this output as the `initial_state` of the gym.

*Please provide this JSON if you are reporting bugs in the gym.*

In [None]:
# Display the gym JSON for the frame currently selected on the slider
# (an empty string if the slider is past the recorded states).
get_env_json()

# Deploy the SageMaker endpoint
This section deploys your new heuristics to the SageMaker endpoint.

In [None]:
# SageMaker session, its default S3 bucket, and the execution role used for deployment.
sage_session = sagemaker.session.Session()
s3_bucket = sage_session.default_bucket()
role = sagemaker.get_execution_role()
print(f"Your sagemaker s3_bucket is s3://{s3_bucket}")

## (Optional) Run if you retrained the model
If you retrained your model in `SagemakerModelTraining.ipynb` but did not create a new endpoint, run the following cell to package `pretrained_models` and upload it to S3 so that the new endpoint uses the updated models.

In [None]:
import os
import tarfile

# Package pretrained_models/ under the top-level directory name "Models",
# the same layout as renaming the folder and running
# `tar -czf Models.tar.gz Models`. Using `arcname` avoids renaming the local
# directory, so a failed or interrupted run can no longer leave
# `pretrained_models` renamed to `Models`.
model_archive = "Models.tar.gz"
with tarfile.open(model_archive, "w:gz") as archive:
    archive.add("pretrained_models", arcname="Models")

try:
    s3_client = boto3.client('s3')
    s3_client.upload_file(model_archive, s3_bucket,
                          "battlesnake-aws/pretrainedmodels/Models.tar.gz")
finally:
    # Remove the local archive even if the upload fails.
    os.remove(model_archive)

## Deploy your new heuristics
Using the new heuristics you developed, a new SageMaker endpoint will be created.

First, delete the old endpoint, endpoint config, and model.

In [None]:
sm_client = boto3.client(service_name='sagemaker')


def _delete_if_exists(delete_call, **kwargs):
    """Call a SageMaker ``delete_*`` API, skipping resources that don't exist.

    This lets the cell also run on a first deployment, when there is nothing
    to delete yet. Any error other than the expected "not found" error is
    re-raised.

    Args:
        delete_call: Bound client method such as ``sm_client.delete_model``.
        **kwargs: Keyword arguments forwarded to ``delete_call``.
    """
    try:
        delete_call(**kwargs)
    except sm_client.exceptions.ClientError as err:
        # NOTE(review): SageMaker is expected to report a missing
        # endpoint/config/model as "ValidationException"; confirm.
        if err.response.get("Error", {}).get("Code") != "ValidationException":
            raise
        print("Skipped: {}".format(err))


_delete_if_exists(sm_client.delete_endpoint, EndpointName='battlesnake-endpoint')
_delete_if_exists(sm_client.delete_endpoint_config, EndpointConfigName='battlesnake-endpoint')
_delete_if_exists(sm_client.delete_model, ModelName="battlesnake-mxnet")

# Endpoint deletion completes asynchronously (per the SageMaker DeleteEndpoint
# docs). Wait until it is gone before an endpoint with the same name is
# created again.
sm_client.get_waiter('endpoint_deleted').wait(EndpointName='battlesnake-endpoint')

Run the following cell to create a new model and endpoint with the new heuristics.

In [None]:
from sagemaker.mxnet import MXNetModel

# Model artifacts in the session's default bucket.
target_key = "battlesnake-aws/pretrainedmodels/Models.tar.gz"
model_data = "s3://{}/{}".format(s3_bucket, target_key)
endpoint_instance_type = "ml.m5.xlarge"

# MXNet model whose inference entry point is predict.py inside the
# battlesnake_inference source directory.
mxnet_model = MXNetModel(
    model_data=model_data,
    role=role,
    entry_point='predict.py',
    source_dir='battlesnake_inference',
    framework_version='1.6.0',
    py_version='py3',
    name="battlesnake-mxnet",
)

# Single-instance endpoint with a fixed name.
predictor = mxnet_model.deploy(
    initial_instance_count=1,
    instance_type=endpoint_instance_type,
    endpoint_name='battlesnake-endpoint',
)

## Testing the new endpoint
You should see `Action to take is X`, where `X` is the action returned by the endpoint.

In [None]:
# Smoke-test the endpoint with an all-zeros observation.
# NOTE(review): the size-2 second dimension of these arrays (and the two
# entries in health_dict) is kept from the original example; confirm it
# matches what predict.py expects for a 4-snake game.
data1 = np.zeros(shape=(1, 2, 3, map_size[0] + 2, map_size[1] + 2))
data2 = np.zeros(shape=(1, 2))
data3 = np.zeros(shape=(1, 2))
data4 = np.zeros(shape=(1, 2))
health_dict = {0: 50, 1: 50}

# Battlesnake-API style game JSON. The board size is derived from map_size so
# it agrees with `state` and `map_width` below. Named `game_json` to avoid
# shadowing the standard-library `json` module name.
game_json = {
    "board": {
        "height": map_size[1],
        "width": map_size[0],
        "food": [],
        "snakes": [],
    },
    "you": {
        "id": "snake-id-string",
        "name": "Sneky Snek",
        "health": 90,
        "body": [{"x": 1, "y": 3}],
    },
}
action = predictor.predict({"state": data1, "snake_id": data2,
                           "turn_count": data3, "health": data4,
                           "all_health": health_dict, "map_width": map_size[0],
                           "json": game_json})
print("Action to take is {}".format(action))