In [1]:
import numpy as np
from gymnasium import spaces
import gymnasium as gym
import matplotlib.pyplot as plt

In [None]:
class Snake_Game(gym.Env):
    """Grid-world Snake environment built on the gymnasium ``Env`` API.

    The board is a square uint8 grid whose cells hold one of the state codes
    defined below (EMPTY, SNAKE, WALL, FOOD). The agent chooses one of three
    discrete actions (LEFT, STRAIGHT, RIGHT).
    """

    # gymnasium looks up supported render modes under "render_modes"; the
    # legacy gym key "render.modes" is kept for older tooling that still reads
    # it. These entries only matter once render() is implemented.
    metadata = {
        'render_modes': ['console', 'rgb_array'],
        'render.modes': ['console', 'rgb_array'],
    }

    # size of the discrete action space
    n_actions = 3

    # actions (presumably turns relative to the current heading; the step()
    # logic that interprets them is not shown here)
    LEFT = 0
    STRAIGHT = 1
    RIGHT = 2

    # cell codes stored in the grid
    EMPTY = 0
    SNAKE = 1
    WALL = 2
    FOOD = 3

    REWARD_WALL_HIT = -20
    # shaping reward; presumably meant to stop the agent from ending episodes
    # on purpose by hitting walls
    REWARD_PER_STEP_TOWARDS_FOOD = 1
    REWARD_PER_FOOD = 50
    # presumably the step budget since the last food before the episode is
    # cut off, to break endless loops
    MAX_STEPS_AFTER_FOOD = 200


    def grid_distance(self, pos1, pos2):
        """Return the Euclidean distance between two grid positions.

        Both positions are converted to float32 arrays before subtracting, so
        the result is a numpy float32 scalar.
        """
        start = np.array(pos1, dtype=np.float32)
        end = np.array(pos2, dtype=np.float32)
        offset = start - end
        return np.linalg.norm(offset)
    
    
    def __init__(self, grid_size=10):
        """Build the initial board, snake and food, plus the gym spaces.

        Args:
            grid_size: side length of the square board, border walls included.
                Must be at least 5 so the fixed starting snake at (1, 1)-(2, 1)
                and the food at (3, 3) lie strictly inside the walls.

        Raises:
            ValueError: if ``grid_size`` is smaller than 5.
        """
        super().__init__()

        if grid_size < 5:
            raise ValueError(f"grid_size must be at least 5, got {grid_size}")

        # step counters
        self.stepnum = 0
        self.last_food_step = 0

        # empty board
        self.grid_size = grid_size
        self.grid = np.full((grid_size, grid_size), self.EMPTY, dtype=np.uint8)

        # border walls on all four sides
        self.grid[0, :] = self.WALL   # top
        self.grid[:, 0] = self.WALL   # left
        self.grid[-1, :] = self.WALL  # bottom
        self.grid[:, -1] = self.WALL  # right

        # snake body, tail first: the last entry is the head (top-left start)
        self.snake_coord = [(1, 1), (2, 1)]
        for coord in self.snake_coord:
            self.grid[coord] = self.SNAKE

        # initial food cell
        self.grid[3, 3] = self.FOOD

        # distance from head to food, used as the baseline for later comparisons
        self.head_dist_to_food = self.grid_distance(
            self.snake_coord[-1],
            np.argwhere(self.grid == self.FOOD)[0]
        )

        # snapshot of the starting state, restored by reset()
        self.init_grid = self.grid.copy()
        self.init_snake_coord = self.snake_coord.copy()

        # action space
        self.action_space = spaces.Discrete(self.n_actions)

        # observation space: head position, heading vector and the full board
        self.observation_space = spaces.Dict(
            spaces={
                "position": spaces.Box(low=0, high=(self.grid_size - 1), shape=(2,), dtype=np.int32),
                "direction": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.int32),
                "grid": spaces.Box(low=0, high=self.FOOD, shape=(self.grid_size, self.grid_size), dtype=np.uint8)
            }
        )
    

    def reset(self, seed=None, options=None):
        """Restore the environment to the state captured in ``__init__``.

        Args:
            seed: optional seed. Forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` is seeded, and also applied to numpy's
                global RNG via ``np.random.seed``.
            options: accepted for gymnasium API compatibility; currently unused.

        Returns:
            ``(obs, info)``: the initial observation dict and an empty info dict.
        """
        # gymnasium's Env.reset (re)seeds self.np_random when a seed is given
        super().reset(seed=seed)
        if seed is not None:
            # the global numpy RNG is still seeded as before; other code may rely on it
            np.random.seed(seed)

        # back to the initial position
        self.stepnum = 0
        self.last_food_step = 0
        self.grid = self.init_grid.copy()
        self.snake_coord = self.init_snake_coord.copy()

        self.head_dist_to_food = self.grid_distance(
            self.snake_coord[-1],
            np.argwhere(self.grid == self.FOOD)[0]
        )

        obs = self._get_obs()
        info = {}

        return obs, info


    def _get_obs(self):
        """Build the observation dict described by ``self.observation_space``.

        Requires at least two snake segments (the direction is derived from
        the last two entries of ``self.snake_coord``).

        Returns:
            dict with
            - "position": head coordinate as an int32 array of shape (2,)
            - "direction": head minus the segment behind it, int32, shape (2,)
            - "grid": a copy of the board (uint8), so later in-place updates
              to ``self.grid`` do not change observations already handed out
        """
        head = np.array(self.snake_coord[-1], dtype=np.int32)
        neck = np.array(self.snake_coord[-2], dtype=np.int32)

        return {
            "position": head,
            # both operands are int32 so the result matches the int32 Box
            "direction": head - neck,
            "grid": self.grid.copy(),
        }