In [1]:
import torch
import torch.nn as nn
import numpy as np
import random
from copy import deepcopy
import time
import math
import tkinter as tk

Classes for models with different sizes:

In [2]:
class Interpreter(nn.Module):
    """Fully connected network: 291 -> 4096 -> 4096 -> 4096 -> 4096 -> 1.

    Takes rows of 291 input features and returns one unbounded scalar per
    row; no activation is applied after the last layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept as they
        # are (including the jump from linear4 to linear8).
        self.linear1 = nn.Linear(291, 4096)
        self.linear2 = nn.Linear(4096, 4096)
        self.linear3 = nn.Linear(4096, 4096)
        self.linear4 = nn.Linear(4096, 4096)
        self.linear8 = nn.Linear(4096, 1)

    def forward(self, x):
        """Run `x` through four ReLU-activated layers and a linear output layer."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = torch.relu(layer(hidden))
        return self.linear8(hidden)
    
class Interpreter_3_layer(nn.Module):
    """Fully connected network: 291 -> 4096 -> 4096 -> 1.

    Same input and output contract as the deeper variant: rows of 291
    features in, one unbounded scalar per row out.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(291, 4096)
        self.linear2 = nn.Linear(4096, 4096)
        self.linear3 = nn.Linear(4096, 1)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers, then the linear output layer."""
        first_hidden = torch.relu(self.linear1(x))
        second_hidden = torch.relu(self.linear2(first_hidden))
        # No output activation: the result is returned unbounded.
        return self.linear3(second_hidden)
    
class Interpreter_3_layer_small(nn.Module):
    """Fully connected network: 291 -> 1024 -> 1024 -> 1.

    Author's note: probably the best performing of the model sizes tried.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(291, 1024)
        self.linear2 = nn.Linear(1024, 1024)
        self.linear3 = nn.Linear(1024, 1)

    def forward(self, x):
        """Return one unbounded scalar per input row of 291 features."""
        activations = x
        for hidden_layer in (self.linear1, self.linear2):
            activations = torch.relu(hidden_layer(activations))
        return self.linear3(activations)

Create a model instance:

In [3]:
# Instantiate the small 3-layer network; the save, load, train and view cells
# further down all operate on this instance.
interpreter3 = Interpreter_3_layer_small()

In [14]:
def find_one(tensor):
    """Return the (row, column) position of the single entry equal to 1.

    The input is squeezed first, so singleton dimensions are ignored; the
    squeezed result must be 2-D. The scoring functions use this to locate
    the head and the food inside one channel of a state.

    Args:
        tensor: tensor that squeezes to a 2-D grid containing exactly one 1.

    Returns:
        Tuple `(row, column)` of Python ints.

    Raises:
        AssertionError: if the squeezed tensor is not 2-D, or if it does not
            contain exactly one entry equal to 1.
    """
    grid = tensor.squeeze()
    assert len(grid.shape) == 2, 'wrong tensor shape'
    # Search the squeezed grid itself: taking the row count from the
    # un-squeezed input would scan only row 0 for shapes such as (1, H, W).
    found = [(row, col) for row, col in (grid == 1).nonzero().tolist()]
    assert len(found) == 1, f'expected exactly one value equal to 1, found {len(found)}'
    return found[0]
        

def score_delay(tensor, delay_value):
    """Return discounted rewards for one episode of states.

    Each element of `tensor` is one state; channel 2 is searched for the
    head position after the move and channel 4 for the food position. A
    step earns 10 when the two coincide and 0 otherwise, and a final -5 is
    appended for the step that ended the game. The reward of step i is the
    sum over all later steps j >= i of `score[j] * delay_value ** (j - i)`.

    Args:
        tensor: iterable of states, each indexable by channel.
        delay_value: discount factor applied per step into the future.

    Returns:
        1-D float tensor with `len(tensor) + 1` discounted rewards.
    """
    scores = []
    for tsr in tensor:
        head_after = find_one(tsr[2])  # next head coords
        food = find_one(tsr[4])        # food coords
        # A Manhattan distance of zero is the same as identical coordinates.
        scores.append(10 if head_after == food else 0)
    scores.append(-5)
    if len(scores) > 10000:
        print('scores getting very long -> snake probably looping')

    # Accumulate from the end so each discounted return is obtained from
    # the next one in constant time instead of re-summing the whole tail.
    rewards = [0.0] * len(scores)
    running = 0.0
    for i in range(len(scores) - 1, -1, -1):
        running = scores[i] + delay_value * running
        rewards[i] = running
    return torch.tensor(rewards)
        
    
def score(tensor):
    """Return per-step "naive" rewards for one episode of states.

    For each state the head position before the move (channel 5), the head
    position after the move (channel 2) and the food position (channel 4)
    are located. Reaching the food earns 10, getting closer to it in
    Manhattan distance earns 1, anything else earns -1. A final -5 is
    appended for the step that ended the game.
    """
    def distance(cell, target):
        # Manhattan distance between two (y, x) positions.
        return abs(cell[1] - target[1]) + abs(cell[0] - target[0])

    rewards = []
    for state in tensor:
        head_before = find_one(state[5])
        head_after = find_one(state[2])
        food = find_one(state[4])

        gap_before = distance(head_before, food)
        gap_after = distance(head_after, food)

        if gap_after == 0:
            rewards.append(10)
        elif gap_after < gap_before:
            rewards.append(1)
        else:
            rewards.append(-1)
    rewards.append(-5)
    return torch.tensor(rewards)

def encode(current, previous, routes) -> torch.Tensor:
    """Flatten and concatenate the model inputs into one row per batch entry.

    Args:
        current: current state tensor, batch dimension first.
        previous: previous state tensor, batch dimension first.
        routes: one-hot route tensor, batch dimension first.

    Returns:
        2-D tensor whose rows are `current`, `previous` and `routes`
        flattened from dimension 1 onward and joined along dimension 1.

    Raises:
        AssertionError: if `current` or `previous` has an empty batch dimension.
    """
    # Both conditions must be part of the asserted expression; written as
    # `assert a, b` the second one is only used as the failure message.
    assert current.shape[0] > 0 and previous.shape[0] > 0, \
        'current and previous must have a non-empty batch dimension'
    return torch.cat((current.flatten(1), previous.flatten(1), routes.flatten(1)), 1)
    

def find_empty(tensor) -> tuple:  #(y, x)
    """Pick a random cell that is empty in the first three channels.

    The three channels of `tensor[0]` are summed; every position whose sum
    is 0 is a candidate, and one candidate is returned as `(y, x)` using
    `random.choice`.
    """
    occupancy = tensor[0][0] + tensor[0][1] + tensor[0][2]
    # nonzero() lists positions row by row, the same order a nested
    # row/column scan would produce.
    free_cells = [(y, x) for y, x in (occupancy == 0).nonzero().tolist()]
    return random.choice(free_cells)


class Snake():
    """State of one snake game on a 6x8 board.

    `state_tensor` has shape (1, 3, 6, 8); its channels are
    0 - snake body, 1 - food, 2 - head. Coordinates are stored as
    [row, column], matching the last two tensor dimensions.
    """

    def __init__(self, cuda = False):
        # `cuda` is accepted but not used anywhere in this constructor.
        self.game_over = False
        self.points = 0
        self.orientation = [1, 0]
        self.head_coords = [1, 1]
        # Independent copy so the segment list never aliases head_coords.
        self.segments = [list(self.head_coords)]

        board = torch.zeros((1, 3, 6, 8))
        head_y, head_x = self.head_coords
        board[0][2][head_y][head_x] = 1
        self.state_tensor = board

        # Food goes on any cell that is still empty once the head is placed.
        self.food_coords = find_empty(board)
        food_y, food_x = self.food_coords
        board[0][1][food_y][food_x] = 1

        self.state_tensor_previous = board.clone()
        
def game(model, iterations, naive_scoring = True,
         delay_value = 0.6, epsilon = 0,
         argmax = False, chopoff = 0):
    """Let `model` play `iterations` games of snake and collect training data.

    At every step the model is evaluated once per route (0 = turn right,
    1 = keep direction, 2 = turn left) on the encoded current state,
    previous state and one-hot route. The three outputs are converted to
    probabilities with a softmax and a route is sampled from them; when
    `argmax` is set the highest-valued route is used instead, except with
    probability `epsilon`, in which case the sampled route is kept.

    Args:
        model: callable mapping one encoded input row to a single value.
        iterations: number of games to play.
        naive_scoring: use score() when True (helps early learning),
            score_delay() otherwise.
        delay_value: discount factor forwarded to score_delay().
        epsilon: with `argmax`, probability of keeping the sampled route.
        argmax: whether to choose the route with the highest prediction.
        chopoff: when non-zero, only the steps from index
            `int((1 - chopoff) * episode_length)` onward are kept for each
            episode (the author marks this as a failed experiment).

    Returns:
        Tuple of
        - encoded inputs, one row per recorded step,
        - scores, one per recorded step,
        - mean points per game,
        - raw recorded steps `[state, previous_state, route, orientation]`,
        - per-episode lists of the recorded steps used for scoring.
    """
    data_iter = []
    data_iter_score = []
    points = []
    for _ in range(iterations):
        data_iter_current = []
        snake = Snake()
        # state_tensor is only ever modified in place, so this view of its
        # three channels and the board size stay valid for the whole game.
        board = snake.state_tensor[0]
        rows, cols = snake.state_tensor.shape[-2:]
        while not snake.game_over:
            # Pure inference: no autograd graph is needed for move selection.
            values = []
            with torch.no_grad():
                for j in range(3):
                    routes = torch.zeros((1, 3))
                    routes[0][j] = 1
                    inputs = encode(snake.state_tensor, snake.state_tensor_previous, routes)
                    values.append(model(inputs).reshape(1))
            predictions = nn.Softmax(0)(torch.cat(values))

            # Sample a route from the softmax probabilities.
            choice = random.random()
            if choice < predictions[0].item():
                predicted_route = 0
            elif choice < predictions[0].item() + predictions[1].item():
                predicted_route = 1
            else:
                predicted_route = 2

            # The second random number is always drawn, even without argmax.
            pick = random.random()
            if argmax and not pick < epsilon:
                predicted_route = torch.argmax(predictions).item()

            # Rotate the [row step, column step] orientation by a quarter turn.
            step_y, step_x = snake.orientation
            if predicted_route == 0: #right
                snake.orientation = [step_x, -step_y]
            elif predicted_route == 2: #left
                snake.orientation = [-step_x, step_y]

            data_iter_current.append([snake.state_tensor.clone(),
                                      snake.state_tensor_previous.clone(),
                                      predicted_route, snake.orientation])

            snake.state_tensor_previous = snake.state_tensor.clone()

            head_y, head_x = snake.head_coords
            next_y = head_y + snake.orientation[0]
            next_x = head_x + snake.orientation[1]

            if not (0 <= next_y < rows and 0 <= next_x < cols): #if oob
                snake.game_over = True
                break
            if board[0][next_y][next_x] == 1:  #if snake crossed
                snake.game_over = True
                break

            if board[1][next_y][next_x] == 1:  #if food eaten
                # New head goes on the front and the last segment remains.
                snake.segments = [[next_y, next_x]] + snake.segments
                board[2][head_y][head_x] = 0     #old head removed
                board[0][head_y][head_x] = 1     #head replaced with body segment
                board[1][next_y][next_x] = 0     #food removed
                board[2][next_y][next_x] = 1     #new head added
                snake.head_coords = [next_y, next_x]
                # NOTE(review): find_empty has nothing to choose from once the
                # snake covers the whole board and will raise in that case.
                snake.food_coords = find_empty(snake.state_tensor)
                board[1][snake.food_coords[0]][snake.food_coords[1]] = 1 #new food added
                snake.points += 10
            else:   #if nothing happened
                tail_y, tail_x = snake.segments[-1]
                board[0][tail_y][tail_x] = 0
                if len(snake.segments) > 1:
                    board[0][head_y][head_x] = 1
                snake.segments = [[next_y, next_x]] + snake.segments[:-1]
                board[2][head_y][head_x] = 0
                board[2][next_y][next_x] = 1
                snake.head_coords = [next_y, next_x]
        points.append(snake.points)

        # Scores for a step are derived from the state recorded one step
        # later, hence the scoring slice starts one element further on.
        begin = int((1 - chopoff) * len(data_iter_current)) if chopoff != 0 else 0
        data_iter.extend(data_iter_current[begin:])
        data_iter_score.append(data_iter_current[1 + begin:])

    #format data so it can be easily plugged into other functions
    score_parts = [torch.tensor([])]
    for episode in data_iter_score:
        if episode:
            tensor_for_score = torch.cat(
                [torch.cat((item[0], item[1]), 1) for item in episode], 0)
        else:
            # chopoff can leave nothing but the terminal step; the scoring
            # functions then return just the terminal score.
            tensor_for_score = torch.zeros((0,))
        if naive_scoring:
            score_parts.append(score(tensor_for_score))
        else:
            score_parts.append(score_delay(tensor_for_score, delay_value))
    scores = torch.cat(score_parts, 0)

    encoded_rows = []
    for item in data_iter:
        routes = torch.zeros((1, 3))
        routes[0][item[2]] = 1
        encoded_rows.append(encode(item[0], item[1], routes))
    # Each encoded row is already (1, features); concatenating keeps the
    # batch dimension even when only a single step was recorded.
    final_data = torch.cat(encoded_rows, 0)

    return final_data, scores, np.mean(points), data_iter, data_iter_score

def train(model, iterations, epochs, loss_fn, optimizer, lr, naive_scoring = True,
          delay_value = 0.6, epsilon = 0, argmax = False, status_frequency = 250,
          cuda = False, weight_decay = 0, momentum = 0, chopoff = 0, skip = 1):
    """Train `model` on games it plays itself.

    Every epoch plays `iterations` games with game(), then performs one
    optimizer step on the collected (encoded input, score) pairs.

    Args:
        model: network to train; it also plays the games.
        iterations: games played per epoch.
        epochs: number of play-then-update rounds.
        loss_fn: loss class, instantiated without arguments.
        optimizer: optimizer class, instantiated with the model parameters.
        lr: learning rate handed to the optimizer.
        naive_scoring, delay_value, epsilon, argmax, chopoff: forwarded to game().
        status_frequency: a status line is printed every this many epochs.
        cuda: accepted but not used.
        weight_decay: forwarded to the optimizer when non-zero.
        momentum: forwarded to the optimizer when non-zero; the chosen
            optimizer class must accept a `momentum` argument in that case.
        skip: stride used to subsample the collected training rows.
    """
    # Only forward the optional settings when they were actually requested,
    # so optimizers that do not know them keep working with the defaults.
    optimizer_kwargs = {'lr': lr}
    if weight_decay:
        optimizer_kwargs['weight_decay'] = weight_decay
    if momentum:
        optimizer_kwargs['momentum'] = momentum
    optimizer = optimizer(params = model.parameters(), **optimizer_kwargs)
    loss_fn = loss_fn()
    points_average = []
    t1 = time.time()

    for epoch in range(epochs):
        geym_done = game(model, iterations, naive_scoring = naive_scoring,
                         delay_value = delay_value, epsilon = epsilon,
                         argmax = argmax, chopoff = chopoff)

        points_average.append(geym_done[2])
        # Scores become a column so they line up with the model output.
        X_train, Y_train = geym_done[0][::skip], geym_done[1].unsqueeze(1)[::skip]

        model.train()
        optimizer.zero_grad()
        predictions = model(X_train)
        loss = loss_fn(predictions, Y_train)
        loss.backward()
        optimizer.step()
        model.eval()

        if epoch % status_frequency == 0:
            # Report time and mean score since the previous status line.
            taim = time.time() - t1
            print(f'epoch: {epoch:<5} | time: {math.floor(taim / 60):>2}:{math.floor(taim % 60):>2}min | average score: {np.mean(points_average):<5}')
            points_average = []
            t1 = time.time()

def view_tensor(data): #simplifies data to 2d tensors
    """Print a readable 2-D view of the previous and the current state.

    `data[0]` is the current state and `data[1]` the previous one. In the
    printed grids a body cell contributes 1, a food cell 2 and a head cell 3.
    """
    def collapse(state):
        # Weighted sum of the three channels of the first batch entry.
        return state[0][0] + 2 * state[0][1] + 3 * state[0][2]

    current = collapse(data[0])
    previous = collapse(data[1])
    print(f'previous:\n{previous}\ncurrent:\n{current}')

def view_tensor_1(tensor):
    """Collapse a state into one 2-D grid.

    Body cells contribute 1, food cells 2 and head cells 3, taken from
    channels 0, 1 and 2 of the first batch entry.
    """
    body, food, head = tensor[0][0], tensor[0][1], tensor[0][2]
    return body + food * 2 + head * 3

def calc_variance(tensor):
    """Return the sample variance (divisor n - 1) of all entries of `tensor`.

    Used to check whether a model has collapsed to predicting a mean value
    instead of learning.
    """
    values = tensor.flatten()
    center = values.mean().item()
    squared_deviations = sum((entry.item() - center) ** 2 for entry in values)
    return squared_deviations / (len(values) - 1)

def view(model):
    """Open a Tk window in which `model` plays one game when Start is pressed.

    Pressing the button plays a single game with game() and then replays
    the recorded states on the canvas, one frame every 20 ms. The call
    blocks until the window is closed.
    """
    root = tk.Tk()
    root.geometry('900x700')
    canvas = tk.Canvas(root, width = 800, height = 600)
    canvas.pack(side = 'top', fill="both", expand=True, pady = 30, padx = 45)

    def snake_draw(canvas, data, index, x_grid, y_grid, fill):
        # Draw frame `index`, then schedule the following one.
        tensor = view_tensor_1(data[index][0])
        n_rows, n_cols = tensor.shape
        # Drop the previous frame's rectangles so canvas items do not pile up.
        canvas.delete('all')
        for i in range(n_rows):
            for j in range(n_cols):
                value = int(tensor[i][j].item())
                canvas.create_rectangle((j * x_grid, i * y_grid, (j + 1) * x_grid, (i + 1) * y_grid)
                                        , fill = fill[value])
        index += 1
        # Continue until every recorded state, including the last, was shown.
        if index < len(data):
            canvas.after(20, lambda : snake_draw(canvas, data, index, x_grid, y_grid, fill))

    def snake_make(canvas):
        # Colours indexed by cell value: empty, body, food, head.
        fill = ['black', 'green', 'yellow', 'teal']
        geym = game(model, iterations = 1, epsilon = 0)
        frames = geym[3]

        # Cell size follows from the canvas size and the board dimensions.
        n_rows, n_cols = frames[0][0].shape[-2:]
        x_grid = int(int(canvas['width']) / n_cols)
        y_grid = int(int(canvas['height']) / n_rows)
        index = 0

        canvas.after(20, lambda : snake_draw(canvas, frames, index, x_grid, y_grid, fill))

    button = tk.Button(root, text = 'Start', width = 100,
                       height = 20,
                       command = lambda x = canvas : snake_make(x))
    button.pack()

    root.mainloop()
            
    
    
    


3-layer-small model: last saved at an average score of 389 (with argmax)

In [10]:
# Write the current weights of interpreter3 to disk (replace the placeholder path first).
# NOTE(review): in top-to-bottom order this cell runs before the load cell below; if both
# point at the same file, a fresh "Run All" overwrites the checkpoint with whatever
# weights are in memory at that moment.
torch.save(interpreter3.state_dict(), 'INSERT PATH HERE')

In [5]:
# Restore previously saved weights into interpreter3 (replace the placeholder path first).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
interpreter3.load_state_dict(torch.load('INSERT PATH HERE'))

<All keys matched successfully>

Training  
--adjust the parameters as necessary, especially the learning rate (decrease it over time) and the delay value (let it approach 1 over time)  
--naive scoring helps to kick-start learning, so turn it off once the model gets a grip

In [13]:
# Train interpreter3 on its own games: 2 games per epoch, one optimizer step per epoch.
# Delayed scoring (naive_scoring = False) with a discount of 0.999 is used here, moves are
# chosen by argmax except for a 1% chance of keeping the sampled move, and a status line
# is printed every 256 epochs.
train(model = interpreter3,
      iterations = 2,
      epochs = 100000,
      loss_fn = nn.MSELoss,
      optimizer = torch.optim.Adam,
      lr = 0.00003,
      naive_scoring = False, 
      epsilon = 0.01,
      delay_value = 0.999,
      argmax = True,
      status_frequency = 256,
      skip = 1
     )

Watch the model play:

In [None]:
# Opens a Tk window and blocks until it is closed; press Start to watch one game.
view(interpreter3)

In [79]:
# Time two games played with argmax move selection and a 1% chance of a sampled move.
t1 = time.time()
geym = game(interpreter3, iterations = 2, naive_scoring = False, argmax = True, epsilon = 0.01)
print(f'time taken: {time.time() - t1:.2f} s')

time taken: 2.30 s


In [80]:
# Third element of game()'s return value: mean points per game over the games just played.
geym[2]

405.0