In [15]:
import pygame
import numpy as np
import torch
import torch.nn as nn

In [61]:
# Window dimensions in pixels.
windowwidth = 400
windowheight = 400
# Side length (pixels) of one snake segment and of the apple square; also the
# distance the snake travels per step.
playersize = 10
# Initial value of snake.len.
startsize = 5
# Initialise pygame and create the display surface that redraw() paints onto.
pygame.init()
win = pygame.display.set_mode((windowwidth, windowheight))
pygame.display.set_caption("snake")
# NOTE(review): `font` is not referenced by any other code visible in this notebook.
font = pygame.font.Font(None, 50)
# Shared clock; the game loops call clock.tick() to cap the frame rate.
clock = pygame.time.Clock()

class snake(object):
    """Snake that moves across the window in steps of ``playersize`` pixels.

    Attributes:
        pos: float array ``[x, y]`` holding the head position.
        dir: array ``[dx, dy]`` added to ``pos`` on every ``move()``.
        len: body length; the trail keeps ``len + 1`` positions.
        prevpos: list of recent head positions, oldest first, newest last.
    """

    def __init__(self):
        # Share one initial state with reset(). The previous integer arrays made
        # ``pos += dir`` raise a casting error as soon as a float direction was
        # assigned, and differed in dtype from the state produced by reset().
        self.reset()

    def move(self):
        """Advance the head by ``dir`` and append the new position to the trail."""
        # Rebinding (rather than ``+=``) keeps the step valid for any mix of
        # integer and float arrays.
        self.pos = self.pos + self.dir
        self.prevpos.append(self.pos.copy())
        # Keep only the newest ``len + 1`` positions.
        self.prevpos = self.prevpos[-self.len-1:]

    def checkdead(self):
        """Return True if the head left the playfield or hit the body, else False."""
        x, y = self.pos[0], self.pos[1]
        # Bounds are hard-coded for the 400x400 window.
        if x < 1 or x > 399 or y < 1 or y > 399:
            return True
        # The last trail entry is the head itself, so it is excluded.
        return any(np.array_equal(self.pos, old) for old in self.prevpos[:-1])

    def reset(self):
        """Put the snake back at (200, 200), heading right, with the start length."""
        self.pos = np.array([200., 200.])
        self.dir = playersize*np.array([1., 0.])
        self.len = startsize
        self.prevpos = [np.array([200., 200.])]

        
class apple(object):
    """Food item that tracks its position and how many times it has been eaten.

    Attributes:
        pos: array ``[x, y]`` with the apple position in pixels.
        score: number of times ``eaten()`` has been called.
    """

    def __init__(self, pos):
        self.pos = pos
        self.score = 0

    def eaten(self):
        """Move the apple to a random spot in [10, 390) on both axes and add one to the score."""
        # ``size`` has to be an integer; the float ``2.`` used before is
        # rejected by NumPy with a TypeError.
        self.pos = np.random.uniform(10., 390., 2)
        self.score += 1
        
# Module-level snake instance shared by rungame() and rungame_ai().
player = snake()
def redraw(goal, player):
    """Clear the window, then paint the snake trail in green and the apple in red.

    Args:
        goal: object whose ``pos`` is the apple centre ``[x, y]`` in pixels.
        player: object whose ``prevpos`` is an iterable of segment centres.

    The display is not flipped here; callers invoke ``pygame.display.update()``.
    """
    half = playersize // 2
    win.fill((0, 0, 0))
    for pos in player.prevpos:
        pygame.draw.rect(win, (0, 255, 0), (pos[0] - half, pos[1] - half, playersize, playersize))
    # Painted once, after the body: it stays on top and is still drawn when the
    # trail is empty (previously it was redrawn inside the loop for every segment).
    pygame.draw.rect(win, (255, 0, 0), (goal.pos[0] - half, goal.pos[1] - half, playersize, playersize))

def check_eaten(player, apple):
    """Let the snake eat the apple when its head is closer than 11 pixels.

    On a hit, ``apple.eaten()`` is called and ``player.len`` grows by one;
    otherwise nothing changes.
    """
    gap = np.linalg.norm(player.pos - apple.pos)
    if not (gap < 11):
        return
    apple.eaten()
    player.len += 1
        
        
def rungame():
    """Run the keyboard-controlled game until the window is closed.

    ``a``/``d``/``w``/``s`` steer left/right/up/down. Each frame (capped at 20
    per second) the snake moves, the apple is checked, the scene is redrawn,
    and the snake is reset if it died.
    """
    # ``size`` has to be an integer; the float ``2.`` used before is rejected by NumPy.
    goal = apple(np.random.uniform(10., 390., 2))

    # Key -> unit heading. Integer vectors keep ``pos += dir`` valid whether the
    # snake's position array is integer or float.
    steering = (
        (pygame.K_a, np.array([-1, 0])),
        (pygame.K_d, np.array([1, 0])),
        (pygame.K_w, np.array([0, -1])),
        (pygame.K_s, np.array([0, 1])),
    )

    run = True
    while run:
        clock.tick(20)
        keys = pygame.key.get_pressed()

        for key, unit in steering:
            # The comparison is all-True only while the snake travels along the
            # other axis, so a key can neither reverse the snake nor re-select
            # the heading it already has.
            if keys[key] and (player.dir != -unit).all():
                player.dir = playersize * unit

        player.move()
        check_eaten(player, goal)
        redraw(goal, player)
        pygame.display.update()

        if player.checkdead():
            player.reset()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False

        pygame.event.pump()

In [75]:
class actor(nn.Module):
    """Fully connected policy network: 8 state values in, 5 action scores out.

    Args:
        hidden_layer_size: width of the three hidden layers.
    """

    def __init__(self, hidden_layer_size):
        super(actor, self).__init__()
        self.fc1 = nn.Linear(8, hidden_layer_size)
        self.fc2 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.fc3 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 5)
        self.relu = nn.ReLU()

    def forward(self, playerpos, applepos, proximity):
        """Score the five actions for the given state.

        Args:
            playerpos: tensor ``[..., 2]`` with the head position.
            applepos: tensor ``[..., 2]`` with the apple position.
            proximity: tensor ``[..., 4]`` of blocked-neighbour flags.

        Returns:
            In training mode, action probabilities (softmax over the last
            dimension). In eval mode, a one-hot tensor marking the
            highest-scoring action.
        """
        x = torch.cat([playerpos, applepos, proximity], dim=-1)
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        hidden = self.relu(self.fc3(hidden))
        logits = self.fc4(hidden)
        if self.training:
            # torch.softmax is numerically stable; exp(l)/sum(exp(l)) overflowed
            # to NaN for large logits (inputs are raw pixel coordinates).
            return torch.softmax(logits, dim=-1)
        # Eval mode: the previous code hit a NameError here (``seros`` typo).
        best = torch.argmax(logits, dim=-1)
        return nn.functional.one_hot(best, num_classes=logits.shape[-1]).to(logits.dtype)
        
# Module-level policy network (hidden width 10) that rungame_ai() samples moves from.
act = actor(10)

def get_proximity(playerpos, prevpos):
    """Flag which of the four cells next to the head are blocked.

    Args:
        playerpos: tensor ``[x, y]`` with the head position.
        prevpos: iterable of ``[x, y]`` positions occupied by the snake.

    Returns:
        Float tensor of shape ``(4,)`` ordered ``[left, right, up, down]``. An
        entry is 1 when stepping ``playersize`` pixels that way would land on a
        position in ``prevpos`` or outside the 1..399 playfield, else 0.
    """
    # Offsets in the same order as the result: left, right, up, down. The old
    # code had left/right swapped, so its wall tests could never trigger, and
    # it tested the x coordinate for the bottom wall.
    steps = playersize * torch.tensor([[-1., 0.], [1., 0.], [0., -1.], [0., 1.]])
    body = [torch.as_tensor(prev) for prev in prevpos]

    result = torch.zeros(4)
    for i, step in enumerate(steps):
        target = playerpos + step
        x, y = target[0], target[1]
        # Bounds are hard-coded for the 400x400 window.
        off_board = bool(x < 1 or x > 399 or y < 1 or y > 399)
        hits_body = any(bool((target == seg).all()) for seg in body)
        if off_board or hits_body:
            result[i] = 1
    return result

class critic(nn.Module):
    """Three-layer fully connected network reducing 8 concatenated inputs to one output."""

    def __init__(self, hidden_layer_size):
        super().__init__()
        width = hidden_layer_size
        self.fc1 = nn.Linear(8, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)
        self.relu = nn.ReLU()

    def forward(self, playerpos, applepos, proximity):
        """Concatenate the three inputs and return the network's single output value."""
        state = torch.cat([playerpos, applepos, proximity])
        # Two hidden layers with ReLU, then a linear read-out.
        for layer in (self.fc1, self.fc2):
            state = self.relu(layer(state))
        return self.fc3(state)
        

In [80]:
def rungame_ai():
    """Run the game with moves sampled from the ``act`` network until the window is closed.

    Each frame (capped at 20 per second) the head position, apple position and
    blocked-neighbour flags are fed to ``act``; one action index is sampled
    from the returned probabilities. Indices 0-3 steer left/right/up/down and
    any other index keeps the current heading.
    """
    # Training mode makes actor.forward() return probabilities to sample from.
    act.train()
    goal = apple(np.random.uniform(10, 390, 2))

    # Action index -> unit heading.
    headings = {
        0: np.array([-1, 0]),
        1: np.array([1, 0]),
        2: np.array([0, -1]),
        3: np.array([0, 1]),
    }

    run = True
    while run:
        clock.tick(20)

        playerpos = torch.from_numpy(player.pos).float()
        applepos = torch.from_numpy(goal.pos).float()
        prevpos = [torch.from_numpy(x).float() for x in player.prevpos]
        proximity = get_proximity(playerpos, prevpos)
        # Nothing is back-propagated here, so skip building an autograd graph.
        with torch.no_grad():
            result = act(playerpos, applepos, proximity)
        move = torch.multinomial(result, 1).item()

        unit = headings.get(move)
        # The comparison is all-True only while the snake travels along the
        # other axis, so an action can neither reverse the snake nor re-select
        # the heading it already has.
        if unit is not None and (player.dir != -unit).all():
            player.dir = playersize * unit

        player.move()
        check_eaten(player, goal)
        redraw(goal, player)
        pygame.display.update()

        if player.checkdead():
            player.reset()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False

        pygame.event.pump()

In [81]:
# Launch the AI-driven game; the call returns once the pygame window is closed.
rungame_ai()