In [None]:
# UNCOMMENT FOR INTERACTIVE PLOTTING
# %matplotlib notebook
%matplotlib widget
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import animation, rc, cm
import time
from snakelib import FastSnake, show_gui, NeuralAgent

rc("animation", html="html5")

(ML:practical_work:genetic_snake)=
# Reinforcement learning on snake with a genetic neural network

:::{admonition} Required files
:class: important
In order to work properly, this notebook requires the following module in its folder:

* {download}`snakelib.py <snakelib.py>`

Put it in your working directory along with this notebook. 

:::


This notebook is an example of reinforcement learning applied to video games. You will use the legendary game Snake, rewritten in Python for the occasion, and try to develop an automatic game strategy: first by hand, then by using a genetic algorithm to evolve a neural network. Graphical examples will let you follow the evolution of the game performance.

## Part 1: Try the game

In this first part, you are asked to try the game and check that you understand the rules. Try to imagine what you need to know to win a game.

In [None]:
# Create a figure with a single axes and hand a fresh 10 x 10 game to the
# GUI helper from snakelib so the game can be tried by hand.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
snake = FastSnake(Nrow=10, Ncol=10)
show_gui(snake, ax)

## Automatic playing

First, you are asked to build an agent that plays automatically. It has to make decisions based on the snake's sensor values: at each step, it chooses the best direction to take.

In [None]:
def my_agent(sensors):
    """
    A simple hand-written agent demo.

    Parameters
    ----------
    sensors : array-like of 5 numbers
        ``sensors[:3]`` are the values of the three cells reachable by the
        three possible moves, in the same order as the returned ids
        ``(1, 0, -1)``: ``1`` is treated as "fruit here", ``0`` as a free cell
        and a negative value as a forbidden cell. ``sensors[3:]`` is a pair
        ``(fcos, fsin)`` -- presumably the cosine and sine of the direction of
        the fruit relative to the snake heading (TODO confirm in snakelib).

    Returns
    -------
    int
        One of ``1``, ``0`` or ``-1``: the id of the chosen move.
    """
    sensors = np.asarray(sensors)
    turn_neighbors = sensors[:3]
    fcos, fsin = sensors[3:]
    direction_ids = np.array([1, 0, -1])
    # FRUIT IS NEAR ? THEN GRAB IT (FIRST MATCHING MOVE WINS)
    fruit_slots = np.flatnonzero(turn_neighbors == 1)
    if fruit_slots.size > 0:
        return direction_ids[fruit_slots[0]]
    # OK, NO, THEN GET CLOSER TO IT ?
    if 0.0 in turn_neighbors:
        preferred_directions = np.zeros(3)
        if fcos > 0.0:
            preferred_directions += (0, 3, 0)
        elif fcos < 0.0:
            preferred_directions += (1, 0, 1)
        if fsin < 0.0:
            preferred_directions += (0, 1, 1)
        elif fsin > 0.0:
            preferred_directions += (1, 1, 0)
        # Zero the preference of every move whose target cell is forbidden.
        p = preferred_directions * (turn_neighbors >= 0.0)
        # argmax returns the first index holding the maximum.
        return direction_ids[np.argmax(p)]
    # NO FREE CELL LEFT: ANY MOVE LOSES, SO PICK A VALID MOVE ID AT RANDOM.
    # The draw is mapped through direction_ids so that the result is always
    # one of 1, 0, -1 (a raw randint(3) would yield 2 and never -1).
    return direction_ids[np.random.randint(3)]


# Quick sanity check: ask the agent for a move on a fresh 10 x 10 game.
snake = FastSnake(Nrow=10, Ncol=10)
my_choice = my_agent(snake.sensors())
my_choice  # last expression of the cell: shown as the cell output

## Automatic play with graphic output


In [None]:
# Game instance (15 x 15 grid) driven by the animation callback of this cell.
snake5 = FastSnake(Nrow=15, Ncol=15)


def updatefig(*args):
    """
    Animation callback: play one move of ``my_agent`` on ``snake5``.

    The image ``im2`` is refreshed with the grid resulting from the move; when
    the game status is no longer 0 the game is reset afterwards, so the final
    grid is still the one handed to the image for this frame.

    Positional arguments supplied by the animation machinery are ignored.
    Returns a 1-tuple holding the updated image artist.
    """
    snake5.turn(my_agent(snake5.sensors()))
    im2.set_array(snake5.grid)
    game_is_over = snake5.status != 0
    if game_is_over:
        snake5.reset()
    return (im2,)


fig2, ax2 = plt.subplots()
ax2.axis("off")
# plt.imshow draws on the current axes, i.e. ax2 created just above.
im2 = plt.imshow(snake5.grid, interpolation="nearest", animated=True)
# Keep a reference to the animation object: the next cell uses it to pause.
anim = animation.FuncAnimation(fig2, updatefig, frames=40, interval=50, blit=True)
plt.show()  # live playback in the interactive figure
# Alternative output (presumably an embedded video, given the html5 animation
# setting of the first cell): comment out plt.show() above and uncomment the
# two lines below.
# plt.close()
# anim

In [None]:
# Stop the running animation and close its figure before moving on.
anim.pause()
plt.close(fig2)

## Benchmark

In [None]:
import tqdm

# BENCHMARK SETUP: PLAY Nagent_ids INDEPENDENT GAMES WITH THE HAND-WRITTEN AGENT
Nagent_ids = 200
max_turns = 1000  # HARD LIMIT ON THE NUMBER OF MOVES PER GAME
snake3 = FastSnake(Nrow=10, Ncol=10)
scores = np.zeros(Nagent_ids)
turns = np.zeros(Nagent_ids)
# Replace range(...) by tqdm.trange(...) below to get a progress bar.
for agent_id in range(Nagent_ids):
    snake3.reset()
    turn = 0
    # A game ends when its status leaves 0 or when the move budget is spent.
    while snake3.status == 0 and turn < max_turns:
        sensors = snake3.sensors()
        my_choice = my_agent(sensors)
        snake3.turn(my_choice)
        turn += 1
    # Record the final score and the number of moves played for this game.
    scores[agent_id], turns[agent_id] = snake3.score, turn

In [None]:
# Summarise the benchmark: after the transpose there is one row per metric
# (score, turns) and one column per selected statistic.
data = pd.DataFrame({"score": scores, "turns": turns})
data.describe().loc[["mean", "std", "max", "min", "count"]].T

## Genetic Neural Network agent 

In [None]:
# NEURAL FUNCTIONS
def sigmoid(x):
    """
    Logistic function ``1 / (1 + exp(-x))``, evaluated element-wise.

    It is computed through the identity ``sigmoid(x) = (1 + tanh(x / 2)) / 2``,
    which is mathematically the same function but never evaluates ``exp`` of a
    large positive number, so strongly negative inputs no longer trigger a
    floating-point overflow.

    Parameters
    ----------
    x : float or numpy array

    Returns
    -------
    Same shape as ``x``, with values in [0, 1].
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def ReLu(x):
    """
    Rectified linear unit, element-wise.

    Entries that are strictly positive are kept, every other entry is replaced
    by 0.0.
    """
    is_positive = x > 0.0
    return np.where(is_positive, x, 0.0)


def identity(x):
    """Pass-through (linear) activation: return the input unchanged."""
    return x


# ARG MAX
def argMax(x):
    """
    Index of the largest entry of the array ``x``, as a Python ``int``.

    When the maximum occurs several times, the first occurrence is returned.
    """
    peak = x.max()
    hits = np.nonzero(x == peak)
    return int(hits[0][0])


# GENETIC ALGORITHM SETUP
# These values are read by the training cell that follows.

Npop = 100  # NUMBER OF INDIVIDUALS IN THE POPULATION
Ngen = 10  # NUMBER OF GENERATIONS OF EVOLUTION
Ntries = 2  # NUMBER OF TRIES PER INDIVIDUAL PER GENERATION
Net_struct = 5, 3  # NETWORK STRUCTURE: LAYER SIZES, FROM INPUT TO OUTPUT
keep_ratio = 0.1  # FRACTION OF THE POPULATION KEPT AT EACH GENERATION
mutation_ratio = 0.1  # PROBABILITY THAT A NEW INDIVIDUAL IS MUTATED
mutation_sigma = 0.1  # MUTATION GAUSSIAN SIGMA
max_turns = 1000  # MAX PLAY TURNS PER TRIAL
neural_functions = [identity]  # NEURAL FUNCTION VECTOR


# PRE-PROCESSING
keep_individuals = int(keep_ratio * Npop)  # NUMBER OF INDIVIDUALS KEPT
# COUNT THE WEIGHTS OF THE NETWORK: (nin + 1) * nout FOR EACH PAIR OF
# CONSECUTIVE LAYERS (THE "+ 1" IS PRESUMABLY A BIAS INPUT - TODO CONFIRM)
Nw = 0
for i in range(len(Net_struct) - 1):
    nin = Net_struct[i]
    nout = Net_struct[i + 1]
    Nw += (nin + 1) * nout
# RANDOM INITIAL POPULATION: ONE ROW OF Nw WEIGHTS PER INDIVIDUAL
all_weights = np.random.normal(loc=0.0, scale=1.0, size=(Npop, Nw))
agents = []
for agent_id in range(Npop):
    # NOTE(review): THIS IS ONE ROW OF all_weights, WHICH THE TRAINING CELL
    # UPDATES IN PLACE; WHETHER NeuralAgent KEEPS A REFERENCE TO IT OR A COPY
    # IS DECIDED IN snakelib - TODO CONFIRM
    weights = all_weights[agent_id]
    agent = NeuralAgent(
        weights=weights, structure=Net_struct, neural_functions=neural_functions
    )
    agents.append(agent)
# ONE CALLABLE PER AGENT, USED BY THE TRAINING CELL AS agent_func(sensors)
agent_functions = [agent.get_caller() for agent in agents]
func = agent_functions[0]
total_generations = 1  # GENERATION COUNTER, PRINTED AND INCREMENTED BY THE TRAINING CELL

In [None]:
# GENETIC TRAINING LOOP
# Each generation: (1) every individual plays Ntries games and is scored by
# its mean game score, (2) the keep_individuals best weight vectors are copied
# to the top of the next population, (3) the rest of the population is bred
# from pairs of those survivors and occasionally mutated.

# DATA STORAGE
snake4 = FastSnake(Nrow=12, Ncol=12)
scores = np.zeros(Npop)
turns = np.zeros(Npop)
tries_scores = np.zeros(Ntries)
tries_turns = np.zeros(Ntries)
new_all_weights = np.zeros_like(all_weights)
# MOVE ID ASSOCIATED WITH EACH OF THE THREE NETWORK OUTPUTS
turn_ids = np.array([-1.0, 0.0, 1.0])
for generation in range(Ngen):
    print(f"Generation: {total_generations}")
    scores[:] = 0.0
    turns[:] = 0.0
    # EVALUATION
    for agent_id in range(Npop):  # tqdm.trange(Npop):
        tries_scores[:] = 0.0
        tries_turns[:] = 0.0
        agent_func = agent_functions[agent_id]
        for trial in range(Ntries):
            snake4.reset()
            Ncol = snake4.Ncol
            Nrow = snake4.Nrow
            # SAME FRUIT POSITION AT THE START OF EVERY TRIAL (PRESUMABLY A
            # FLATTENED ROW-MAJOR CELL INDEX - TODO CONFIRM IN snakelib)
            snake4.fruit_position = (Nrow - 2) * Ncol + Ncol - 2
            turn = 0
            while snake4.status == 0:
                sensors = snake4.sensors()
                # THE MOVE IS THE ONE WITH THE LARGEST NETWORK OUTPUT
                my_choice = turn_ids[argMax(agent_func(sensors))]
                snake4.turn(my_choice)
                turn += 1
                if turn >= max_turns:
                    break
            tries_scores[trial] = snake4.score
            tries_turns[trial] = turn
        scores[agent_id] = tries_scores.mean()
        turns[agent_id] = tries_turns.mean()
    perf = scores
    order = np.argsort(perf)[::-1]  # BEST INDIVIDUAL FIRST
    new_all_weights[:] = 0.0
    # SELECTION
    new_all_weights[:keep_individuals] = all_weights[order][:keep_individuals]
    # HYBRIDATION
    keep_range = np.arange(keep_individuals)
    for indiv in range(keep_individuals, Npop):
        # TWO DISTINCT PARENTS AMONG THE SURVIVORS. replace=False GUARANTEES
        # THEY DIFFER AND RAISES A ValueError (INSTEAD OF LOOPING FOREVER)
        # WHEN FEWER THAN TWO INDIVIDUALS ARE KEPT.
        parents = np.random.choice(keep_range, 2, replace=False)
        # CHILD = CONVEX BLEND OF ITS PARENTS: pw * p0 + (1 - pw) * p1
        pw = np.random.rand()
        new_all_weights[indiv] = (
            pw * new_all_weights[parents[0]]
            + (1.0 - pw) * new_all_weights[parents[1]]
        )
        # MUTATION: WITH PROBABILITY mutation_ratio, EVERY WEIGHT OF THE CHILD
        # IS MULTIPLIED BY GAUSSIAN NOISE CENTERED ON 1
        if np.random.rand() <= mutation_ratio:
            new_all_weights[indiv] *= np.random.normal(
                loc=1.0, scale=mutation_sigma, size=Nw
            )
    total_generations += 1
    # IN-PLACE UPDATE, SO THAT OBJECTS HOLDING A REFERENCE TO all_weights
    # OBSERVE THE NEW POPULATION
    all_weights[:] = new_all_weights
    print(f"=> best score = {scores.max()}")

    data = pd.DataFrame(
        {"score": scores[order], "turns": turns[order], "perf": perf[order]}
    )
    print(data.head(5))
print("FINISHED")

In [None]:
# Fresh 12 x 12 game used to replay a trained agent.
snake5 = FastSnake(Nrow=12, Ncol=12)

# weights = all_weights[0]  # BEST AGENT
# Agent 0 is taken as the best one: the selection step of the training loop
# stores the top-ranked weights in the first rows of all_weights.
# NOTE(review): this assumes agent 0 reads row 0 of all_weights by reference;
# confirm in snakelib.
best_agent_func = agent_functions[0]
# Move id associated with each of the three network outputs.
turn_ids = np.array([-1.0, 0.0, 1.0])


def updatefig(*args):
    """
    Animation callback: play one move of the selected neural agent on ``snake5``.

    The move is the entry of ``turn_ids`` matching the largest output of
    ``best_agent_func`` (and not of ``agent_func``, which is merely the last
    agent evaluated by the training loop). The image ``im2`` is then refreshed
    and the game is reset once its status is no longer 0.

    Positional arguments supplied by the animation machinery are ignored.
    Returns a 1-tuple holding the updated image artist.
    """
    sensors = snake5.sensors()
    my_choice = turn_ids[argMax(best_agent_func(sensors))]

    snake5.turn(my_choice)
    im2.set_array(snake5.grid)
    if snake5.status != 0:
        snake5.reset()
    return (im2,)


fig2, ax2 = plt.subplots()
ax2.axis("off")
# plt.imshow draws on the current axes, i.e. ax2 created just above.
im2 = plt.imshow(snake5.grid, interpolation="nearest", animated=True)
# Keep a reference to the animation object in a variable of the notebook.
anim = animation.FuncAnimation(fig2, updatefig, frames=500, interval=50, blit=True)
plt.show()