In [1]:
import gym
from gym import spaces
from gym.utils import seeding
from collections import deque
%matplotlib notebook

In [2]:
import numpy as np

In [66]:
class SnakeEnv(gym.Env):
    """Snake game on a rectangular grid, wrapped as a gym environment.

    Positions are ``(row, col)`` tuples indexing ``self.grid``, whose shape is
    ``(height, width)``.

    Attributes set by this class:
        game_over: True once the snake has hit a wall or itself (or the board
            has no free cell left for food). Cleared by ``reset()``.
        width, height: board size in columns and rows.
        grid: object array holding ``0.0`` for empty cells, ``"*"`` for snake
            body, ``"H"`` for the snake head and ``"X"`` for food.
        snake: deque of positions ordered tail (left) to head (right).
        direction: last direction moved; 0 = down, 1 = left, 2 = up, 3 = right.
        score: number of food items eaten since the last ``reset()``.
        food: position of the current food item, or None when the board is full.

    ``step()`` is still a stub; the game is currently driven through ``move()``.
    """
    #### GYM IMPLEMENTATION ###
    metadata = {'render.modes': ['human', 'not_human']}

    # (row, col) offset applied to the head for each action.
    _DELTAS = {
        0: (1, 0),   # down
        1: (0, -1),  # left
        2: (-1, 0),  # up
        3: (0, 1),   # right
    }

    def __init__(self):
        """Create a 10 x 10 board and start a new game via ``reset()``."""
        # Definitions for game.
        # Board is 10 x 10 to start.
        self.game_over = False
        self.width = 10
        self.height = 10
        self.grid = self._empty_grid()
        self.snake = deque()
        # direction 0 = down, 1 = left, 2 = up, 3 = right
        self.direction = 0
        # Running score
        self.score = 0
        self.food = None
        # Gym implementation members:
        self.action_space = spaces.Discrete(4)
        # each square can be empty, food, snake, snake head
        # np.int64 replaces the np.int alias, which no longer exists in NumPy.
        self.observation_space = spaces.Box(
            low=np.array([0, 0, 0]),
            high=np.array([self.width, self.height, 4]),
            dtype=np.int64,
        )

        self.reset()

    def step(self, action):
        """Placeholder for the gym step API; not implemented yet.

        Currently ignores ``action`` and returns None.
        """
        # TODO: implement the gym step contract on top of move().
        return

    def reset(self):
        """Start a new game.

        Clears the game-over flag and the score, places a two-cell snake
        heading down at a random position away from the border, places the
        first food item on a free cell, and redraws ``self.grid``.
        """
        self.game_over = False
        self.score = 0
        self.direction = 0
        self.snake = deque()
        # Keep the tail off the border; rows come from height, columns from
        # width. The head sits one row below the tail, still inside the board.
        i = np.random.randint(0, self.height - 2) + 1
        j = np.random.randint(0, self.width - 2) + 1
        self.snake.append((i, j))
        self.snake.append((i + 1, j))
        self.add_food()
        self._redraw_grid()

    def render(self, mode='human', close=False):
        """Rebuild ``self.grid`` from the current snake and food positions.

        ``mode`` and ``close`` are accepted for gym compatibility but unused.
        Nothing is returned and the snake itself is left untouched.
        """
        self._redraw_grid()

    ### SNAKE IMPLEMENTATION ###
    def move(self, action):
        """Advance the snake one cell.

        Args:
            action: 0 = down, 1 = left, 2 = up, 3 = right. An action that
                would reverse the snake onto itself is replaced by the
                current direction.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.

        Once ``game_over`` is set, the snake is not moved any further until
        ``reset()`` is called. Eating the food grows the snake by one cell,
        adds one to ``score`` and places new food.
        """
        if action not in self._DELTAS:
            raise ValueError("action must be 0, 1, 2 or 3, got %r" % (action,))
        ## check direction, if direction is opposite of input, continue straight
        if action == (self.direction + 2) % 4:
            action = self.direction
        self.direction = action

        d_row, d_col = self._DELTAS[action]
        head_row, head_col = self.snake[-1]
        new_pos = (head_row + d_row, head_col + d_col)

        # Check if snake CRASHES; a crashing position is never added to the
        # snake so the board can still be rendered afterwards.
        self.check_game_over(new_pos)
        if self.game_over:
            return
        self.snake.append(new_pos)

        # Check if snake grows
        if new_pos != self.food:
            self.snake.popleft()
        else:
            self.score += 1
            self.add_food()

    def add_food(self):
        """Place food on a uniformly random cell not occupied by the snake.

        Every cell of the board is eligible. If the snake covers the whole
        board there is nowhere to put food: ``food`` becomes None and
        ``game_over`` is set.
        """
        occupied = set(self.snake)
        free_cells = [
            (row, col)
            for row in range(self.height)
            for col in range(self.width)
            if (row, col) not in occupied
        ]
        if not free_cells:
            self.food = None
            self.game_over = True
            return
        # randint's upper bound is exclusive, so this indexes every free cell.
        self.food = free_cells[np.random.randint(0, len(free_cells))]

    def check_game_over(self, new_pos):
        """Set ``game_over`` if ``new_pos`` is on the snake or off the board.

        Args:
            new_pos: ``(row, col)`` the head is about to move to.
        """
        # Check if snake hits itself
        if new_pos in self.snake:
            self.game_over = True
        # Check if snake hits the wall; valid rows are 0..height-1 and valid
        # columns are 0..width-1.
        row, col = new_pos
        if not (0 <= row < self.height and 0 <= col < self.width):
            self.game_over = True

    ### INTERNAL HELPERS ###
    def _empty_grid(self):
        """Return a new (height, width) object array filled with 0.0."""
        return np.full((self.height, self.width), 0.0, dtype=object)

    def _redraw_grid(self):
        """Replace ``self.grid`` with a fresh drawing of food and snake."""
        self.grid = self._empty_grid()
        if self.food is not None:
            # X is food
            self.grid[self.food] = "X"
        if self.snake:
            # Unpack into new lists so the deque itself is never modified.
            *body, head = self.snake
            # * is a nonhead part of the snake
            for cell in body:
                self.grid[cell] = "*"
            # H is the head of the snake
            self.grid[head] = "H"
    
    
    
    

In [67]:
# Create the environment (the constructor calls reset()) and show the initial board.
env = SnakeEnv()
env.grid

array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, '*', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 'H', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 'X', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=object)

In [68]:
# Request three moves with action 0 (down).
env.move(0)
env.move(0)
env.move(0)

In [69]:
# Rebuild env.grid from the snake and food positions; render() returns nothing,
# so the board is displayed in the next cell.
env.render()

In [71]:
# Display the board as last drawn by render().
env.grid

array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, '*', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, '*', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 'H', 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 'X', 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=object)

In [22]:
# Choose a new food position that is not on the snake.
# NOTE(review): the "(7, 5) 7 5" line printed below is not produced by the
# add_food defined in this notebook (it has no print call); it presumably comes
# from an earlier version of the method - re-run the notebook to refresh it.
env.add_food()

(7, 5) 7 5


In [23]:
# Inspect the current food position, a tuple stored by add_food().
env.food

(7, 5)

In [28]:
# Re-roll the food position and display the new value.
env.add_food()
env.food

(0, 5) 0 5


(0, 5)