In [14]:
import gym
from gym import spaces
from gym.utils import seeding
from collections import deque
from collections import defaultdict
import copy
%load_ext autoreload

%autoreload 2
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import tensorflow as tf

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

Using TensorFlow backend.


In [2]:
import numpy as np

In [3]:
class SnakeEnv(gym.Env):
    #### GYM IMPLEMENTATION ###
    metadata = {'render.modes': ['human', 'not_human']}
    def __init__(self):
    # Definitions for game.
    # Board 15 x 15 to start
        self.score = 0
        self.game_over = False 
        self.width = 10
        self.height = 10
        self.grid = np.array([np.zeros(self.width) for i in range(self.height)], dtype=object)
        self.snake = deque()
        # direction 0 = down, 1 = left, 2 = up, 3 = right
        self.direction = 0
    # Running score
        self.score = 0
    #Gym implementation members:
        self.action_space = spaces.Discrete(4)
        # each square can be empty, food, snake, snake head
        self.observation_space = spaces.Box(low = np.array([0,0,0]), high=np.array([self.width, self.height, 4]), dtype=np.int)

        self.reset()
    def step(self, action):
        ''' Perform one step of the game'''
        self.move(action)
        info = {}
        return self.grid, self.score, self.game_over, info
    def reset(self):
        self.direction = 0
        self.snake = deque()
        self.grid = np.array([np.zeros(self.width) for i in range(self.height)], dtype=object)
        i = np.random.randint(0,self.width- 2 ) + 1
        j = np.random.randint(0,self.height - 2) + 1 
        # * is a nonhead part of the snake
        self.grid[i,j] = "*"
        self.snake.append((i,j))
        # H is the head of the snake
        self.grid[i+1,j] = "H"
        self.snake.append((i+1,j))
        # X is food, set initial food
        k,l = i,j
        while ((i==k and j==l) or (i+1 ==k and j==l)):
            k = np.random.randint(0,self.width- 2 ) + 1
            l = np.random.randint(0,self.height - 2) + 1
        self.grid[k,l] = "X"
        self.food = (k,l)
    def updategrid(self):
        self.grid = np.array([np.zeros(self.width) for i in range(self.height)], dtype=object)
        temp_snake = self.snake.copy()
        self.grid[temp_snake[-1]] = "H"
        temp_snake.pop()
        for body in temp_snake:
            self.grid[body] = "*"
        self.grid[self.food] = "X"
    def render(self, mode='human', close=False):
        print(self.grid)
                        
    
    ### SNAKE IMPLEMENTATION ###
    def move(self, action):
        ## check direction, if direction is opposite of input, continue straight
        if (self.direction %2 == action %2):
            action = self.direction
        self.direction = action
        new_pos = self.snake[-1]
        if (action == 0): #down
            new_pos_row = self.snake[-1][0]+1
            new_pos_col = self.snake[-1][1]
        if (action == 1): #left
            new_pos_row = self.snake[-1][0]
            new_pos_col = self.snake[-1][1]-1
        if (action == 2): #up
            new_pos_row = self.snake[-1][0]-1
            new_pos_col = self.snake[-1][1]
        if (action == 3): #right
            new_pos_row = self.snake[-1][0]
            new_pos_col = self.snake[-1][1]+1
        # Check if snake CRASHES
        new_pos = (new_pos_row, new_pos_col)
        self.check_game_over(new_pos)
        if (self.game_over):
            self.updategrid()
            return
        self.snake.append(new_pos)
        
        # Check if snake grows
        if (new_pos != self.food):
            self.snake.popleft()
        else:
            self.score +=1
            self.add_food()
        self.updategrid()
    def add_food(self):
        x_coord = np.random.randint(0,self.width-1)
        y_coord = np.random.randint(0,self.height-1)
        self.food = (x_coord, y_coord)
        while(self.food in self.snake):
            x_coord = np.random.randint(0,self.width-1)
            y_coord = np.random.randint(0,self.height-1)
            self.food = (x_coord, y_coord)
            
    def check_game_over(self, new_pos):
        # Check if snake hits itself
        if new_pos in self.snake:
            self.game_over = True 
        # Check if snake hits the wall
        if ((new_pos[0]<0) or (new_pos[0]>=self.width) or (new_pos[1]<0) or (new_pos[1]>=self.height)):
            self.game_over = True
    
    
    
    

In [4]:
def random_game():
    env = SnakeEnv()
    while(not env.game_over):
        action = np.random.randint(0,4)
        env.step(action)
        env.render()
        print(action, env.game_over)
    return

In [5]:
def montecarlo(game, number_of_games, prints=False):
    def random_game_initial_score(game, initial_move):
        grid,_, lost, _ = game.step(initial_move)
        #print(grid)
        while( not lost):
            grid,_, lost, _ = game.step(np.random.randint(4))
            #print(grid)
        return initial_move, game.score
    def get_move(game):
        grids = []
        for i in range(number_of_games):
            grids.append(copy.deepcopy(game))
        move_scores = defaultdict(list)
        i=0
        for games in grids:
            initial_move, score = random_game_initial_score(games, i%4)
            move_scores[initial_move].append(score)
            i+=1
        score = -1
        for key, value in move_scores.items():
            if (prints == True):
                print(key, sum(value)/len(value))
            if (sum(value)/len(value) > score):
                move = key
                score = sum(value)/len(value)
        if (prints == True):
            print(move, score)
        return move
    states = []
    lost = False
    while (not lost):
        move = get_move(game)
        if (prints == True):
            game.render()
        _, _, lost, _ = game.step(move)
    return(game.score)
            

In [12]:
env = SnakeEnv()
montecarlo(env,100,True)


0 0.04
1 0.28
2 0.08
3 0.08
1 0.28
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 'X' 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'H' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 0.28
1 0.36
2 0.36
3 0.32
1 0.36
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 'X' 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 'H' '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 0.56
1 0.52
2 0.44
3 0.32
0 0.56
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

0 2.04
1 2.04
2 2.92
3 2.08
2 2.92
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'X' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'H' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 3.08
1 2.32
2 3.04
3 2.16
0 3.08
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'X' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'H' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 3.68
1 3.28
2 3.68
3 3.08
0 3.68
[[0.0 0.0 0.0 0.0 0.0 0.0 'X' 0.0

0 6.0
1 6.24
2 6.0
3 6.16
1 6.24
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 'X' 0.0 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'H' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 6.0
1 6.16
2 6.04
3 6.16
1 6.16
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 'X' 0.0 0.0 0.0 '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 'H' '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 6.0
1 6.2
2 6.12
3 6.08
1 6.2
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0

0 9.2
1 9.0
2 9.2
3 9.04
0 9.2
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'X' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 '*' '*' '*' 0.0 'H' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 9.12
1 9.2
2 9.36
3 9.12
2 9.36
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 'X' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 'H' 0.0 0.0]
 [0.0 0.0 0.0 0.0 '*' '*' 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' 0.0 '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' '*' '*' 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 '*' '*' 0.0 0.0 0.0]
 [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]]
0 9.32
1 9.16
2 9.4
3 9.08
2 9.4
[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.

13

In [7]:
a,b

NameError: name 'a' is not defined

In [None]:
env.game_over

In [None]:
env.add_food()

In [8]:
env.food

(8, 8)

In [9]:
env.add_food()
env.food

(8, 1)