In [1]:
# Imports the libraries and makes matplotlib appear in a second window:

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import tensorflow as tf
from tensorflow import keras
import random

from collections import deque

%matplotlib qt

In [2]:
class Snake ():
    """Grid Snake game used as the environment for the Q-network.

    Board encoding: 0 = empty, 1 = food, 2 = snake body (head included).
    Actions: 0 = up, 1 = right, 2 = down, 3 = left.
    """

    # (row step, column step) for each action index; same order as movingDirection and getFoodDirection.
    _ACTION_DELTAS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    # Rays scanned by snakeVision, clockwise starting from "up".
    _VISION_DIRECTIONS = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))

    def __init__ (self, nrows, ncols):
        """Start a new game: a length-1 snake in the centre heading up, plus one food item.

        move() also calls this to restart the game after the snake dies.
        """
        self.nrows, self.ncols = nrows, ncols
        self.board = np.zeros((nrows, ncols))  # 0 for empty, 1 for food and 2 for snake body
        self.headY, self.headX = nrows // 2, ncols // 2
        self.tailY, self.tailX = self.headY, self.headX
        self.foodY, self.foodX = None, None
        self.movingDirection = np.array([1, 0, 0, 0])  # one-hot of the last action; starts as "up"
        self.snakeBody = deque()  # body cells as (y, x), head at index 0, tail at index -1
        self.snakeBody.appendleft((self.headY, self.headX))
        # Mark the starting head on the board so board and snakeBody agree. Without this,
        # the segment stays invisible to snakeVision if the first move eats food.
        self.board[self.headY, self.headX] = 2
        self.createFood()

    def checkIfBelongsToSnake (self, y, x):
        """Return True if cell (y, x) is occupied by the snake."""
        return (y, x) in self.snakeBody

    def createFood(self):
        """Put food (value 1) on a uniformly random cell not occupied by the snake.

        If the snake covers the whole board there is no free cell. In that case
        foodY/foodX are set to None instead of looping forever.
        """
        occupied = set(self.snakeBody)
        freeCells = [(y, x) for y in range(self.nrows) for x in range(self.ncols)
                     if (y, x) not in occupied]
        if not freeCells:
            self.foodY, self.foodX = None, None
            return
        self.foodY, self.foodX = random.choice(freeCells)
        self.board[self.foodY, self.foodX] = 1

    def move(self, action):
        """Advance the game by one step in direction `action`.

        Returns (board, done, reward). The reward is 1 when food is eaten, -1 when the
        snake hits a wall or its own body, and 0 otherwise. On death the game is reset
        and the fresh board is returned with done=True.
        Raises Exception("Invalid action") for actions other than 0..3.
        """
        if action not in self._ACTION_DELTAS:
            raise Exception ("Invalid action")
        dy, dx = self._ACTION_DELTAS[action]
        newY, newX = self.headY + dy, self.headX + dx

        if not (0 <= newY < self.nrows and 0 <= newX < self.ncols):
            # Snake dies if it tries to go out of the area
            self.__init__(self.nrows, self.ncols)
            return self.board, True, -1

        self.headY, self.headX = newY, newX
        self.movingDirection = np.zeros((4))
        self.movingDirection[action] = 1

        # Bumping into the body kills the snake. The current tail cell is exempt
        # because the tail moves away during this same step.
        if (self.checkIfBelongsToSnake(self.headY, self.headX)
                and (self.headY, self.headX) != (self.tailY, self.tailX)):
            self.__init__(self.nrows, self.ncols)
            return self.board, True, -1

        self.snakeBody.appendleft((self.headY, self.headX))
        if self.board[self.headY, self.headX] == 1:  # collected an apple: grow (tail stays)
            reward = 1
            self.createFood()
        else:
            reward = 0
            self.board[self.tailY, self.tailX] = 0
            self.snakeBody.pop()
            # The deque still holds the new head, so it is never empty here.
            self.tailY, self.tailX = self.snakeBody[-1]
        self.board[self.headY, self.headX] = 2
        return self.board, False, reward

    def _scan(self, dy, dx):
        """Walk from the head along (dy, dx) and return (wall, body, food) distances in cells.

        `wall` is the number of steps needed to leave the board. `body` and `food` are the
        step at which the first cell holding 2 or 1 is met, or 0 if none is met before the wall.
        """
        limits = []
        if dy < 0:
            limits.append(self.headY + 1)
        elif dy > 0:
            limits.append(self.nrows - self.headY)
        if dx < 0:
            limits.append(self.headX + 1)
        elif dx > 0:
            limits.append(self.ncols - self.headX)
        wall = min(limits)

        body = food = 0
        for step in range(1, wall):
            cell = self.board[self.headY + step * dy, self.headX + step * dx]
            if cell == 2 and not body:
                body = step
            elif cell == 1 and not food:
                food = step
            if body and food:
                break
        return wall, body, food

    def snakeVision (self):
        """Return the 24-value proximity vector seen from the head.

        For each of 8 rays (up, up-right, right, ..., up-left, clockwise) it holds three
        values: proximity to the wall, to the nearest body part and to food. Each is
        1/distance, or 0 when nothing was found (see normaliseVision).
        """
        state = np.zeros((24))
        for k, (dy, dx) in enumerate(self._VISION_DIRECTIONS):
            state[3 * k:3 * k + 3] = self._scan(dy, dx)
        return self.normaliseVision(state)

    def normaliseVision (self, state):
        """Turn distances into proximities in place: d -> 1/d, and 0 stays 0 ("not found").

        For example, 1 means the next cell, 0.5 means two cells away, 0.33 means three cells away.
        """
        nonzero = state != 0  # skip zeros to avoid division by zero
        state[nonzero] = 1 / state[nonzero]
        return state

    def getFoodDirection(self): # Smelling sense
        """Return a 4-vector (north, east, south, west) with 1 where the food lies that way.

        Two entries are set when the food is diagonal to the head, and one when it is in
        the same row or column. All entries are 0 when there is no food on the board.
        """
        foodDir = np.zeros((4))
        if self.foodY is None or self.foodX is None:
            return foodDir
        if (self.foodY < self.headY):
            foodDir[0] = 1 # North
        if (self.foodX > self.headX):
            foodDir[1] = 1 # East
        if (self.foodY > self.headY):
            foodDir[2] = 1 # South
        if (self.foodX < self.headX):
            foodDir[3] = 1 # West
        return foodDir

In [3]:
# Q-network: the 24-value ray "vision" goes through an MLP. The one-hot moving direction
# and the food-direction ("smell") vector are concatenated just before the output layer.
# Note: `(24)` is just the int 24, so the shapes are written as proper 1-tuples.
stateInput = keras.Input(shape = (24,))     # Snake.snakeVision() output
directionInput = keras.Input(shape = (4,))  # one-hot of the last action (Snake.movingDirection)
foodInput = keras.Input(shape = (4,))       # Specifies in which direction the next apple is (up to two entries set)

dense1 = keras.layers.Dense(units = 256, activation = "relu")(stateInput)
dense2 = keras.layers.Dense(units = 256, activation = "relu")(dense1)
dense3 = keras.layers.Dense(units = 128, activation = "relu")(dense2)
dense4 = keras.layers.Dense(units = 8, activation = "relu")(dense3)
dense5 = keras.layers.Concatenate()([dense4, directionInput, foodInput])

outputLayer = keras.layers.Dense(units = 4)(dense5)  # one linear output per action

model = keras.Model(inputs = [stateInput, directionInput, foodInput], outputs = outputLayer)
# `lr` is a deprecated alias that newer Keras optimizers reject; use `learning_rate`.
model.compile(optimizer = keras.optimizers.Adam(learning_rate = 0.001), loss = "mse")

# summary() prints the table itself and returns None, so wrapping it in print() only adds "None".
model.summary()

# clone_model copies the architecture only (the weights are freshly initialised),
# so copy the weights explicitly to make both networks start out identical.
staticModel = keras.models.clone_model(model, input_tensors = None, clone_function = None)
staticModel.set_weights(model.get_weights())

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 24)]         0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 256)          6400        input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 256)          65792       dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 128)          32896       dense_1[0][0]                    
______________________________________________________________________________________________

In [4]:
# MATPLOTLIB is interactive in the new window:

plt.ion()
plt.show()

snake = Snake (9, 9)


numberEpisodes = 3
maxSteps = 1000  # per-episode step cap; the game is reset when it is reached

discount = 0.99
counter = 0

# One colour per displayed value: 0 empty, 1 food, 2 body, 3 head, 4 tail.
# Values 3 and 4 exist only in the displayed copy. Built once, not on every frame.
cmap = mpl.colors.ListedColormap(["purple", "yellow", "mediumturquoise", "lightseagreen", "turquoise"])
norm = mpl.colors.BoundaryNorm(np.arange(0, 6, 1), cmap.N)


def drawBoard(game, board):
    """Redraw `board` in the current figure, highlighting the head and tail of `game`."""
    plt.clf()
    arrayToShow = np.copy(board)  # copy so the game's own board is not modified
    arrayToShow[game.tailY, game.tailX] = 4
    arrayToShow[game.headY, game.headX] = 3
    plt.imshow(arrayToShow, cmap=cmap, norm=norm)
    ax = plt.gca()
    ax.set_xticks([x - 0.5 for x in np.arange(game.ncols)], minor=True)
    ax.set_yticks([y - 0.5 for y in np.arange(game.nrows)], minor=True)
    plt.grid(which="minor")
    plt.xticks([], [])
    plt.yticks([], [])
    plt.pause(0.02)  # lets the interactive window refresh


rewards = []
model.load_weights("Snake1.h5")
for episode in range (numberEpisodes):

    done = False
    episodeReward = 0
    prevObs = None

    for step in range(maxSteps):
        if (step == maxSteps - 1): # In case it is going to time out, it has to reset
            snake.__init__(snake.nrows, snake.ncols)
            break
        if (done):
            break

        if (prevObs is None):
            action = random.choice([0, 1, 2, 3])  # no observation yet at the start of an episode
        else:
            # Greedy action from the network on the latest observation (batch of one).
            previousObs = np.expand_dims(prevObs, axis = 0)
            previousDirec = np.expand_dims(prevDirections, axis = 0)
            previousFoodDir = np.expand_dims(prevFoodDirections, axis = 0)
            action = np.argmax(model.predict([previousObs, previousDirec, previousFoodDir], verbose = 0))

        state, done, reward = snake.move(action)
        drawBoard(snake, state)

        # Observations of the state reached by this move feed the next action choice.
        obs = snake.snakeVision()
        obsDirections = snake.movingDirection
        foodDirections = snake.getFoodDirection()

        prevDirections = obsDirections
        prevObs = obs
        prevFoodDirections = foodDirections
        episodeReward += reward

    print (episode, " Episode Reward: ", episodeReward, ".Steps:", step)
    rewards.append(episodeReward)


0  Episode Reward:  22 .Steps: 223
1  Episode Reward:  20 .Steps: 176
2  Episode Reward:  23 .Steps: 165


In [5]:
# Plot the reward of each episode together with their mean as a flat reference line.
episodeIndices = np.arange(len(rewards))
meanReward = np.average(rewards)

plt.clf()
plt.plot(episodeIndices, rewards)
plt.plot(episodeIndices, np.full(len(rewards), meanReward))
plt.title("Rewards obtained along the learning process")
plt.xlabel("Episode number")
plt.ylabel("Reward obtained")
plt.legend(["Rewards", "Average: %.2f" % meanReward])
plt.show()