In [1]:
import numpy as np
import torch
from src.env import SnakeEnv

# Snake environment on a 10x10 board; each observation stacks 3 frames
# (see the (2, 3, 10, 10) batch shape built in the next cell).
env = SnakeEnv(
    history_length=3,
    width=10,
    height=10,
    action_space_type='relative',
)

In [2]:
# Reset the environment (81 is presumably the seed -- TODO confirm against SnakeEnv.reset).
obs, _ = env.reset(81)

# Take action 2 twice, then snapshot the observation as the first sample.
for _step in range(2):
    obs, r, done, _ = env.step(2)
x = torch.Tensor(obs)[None]  # leading axis of size 1 = batch dimension

# One more step (action 1) gives the second sample.
obs, r, done, _ = env.step(1)
y = torch.Tensor(obs)[None]

# Stack both samples along the batch axis.
z = torch.cat([x, y], dim=0)
z.shape # a batch of 2 samples

torch.Size([2, 3, 10, 10])

In [3]:
# Locate the cells equal to -1 in the most recent frame (channel 0) of every sample.
food_hits = torch.argwhere(z[:, 0, :, :] == -1)
print(food_hits)
print(food_hits.shape)
# Each row of the result is (batch index, index along board axis 1, index along board axis 2).
# Dropping the first column (the batch index) leaves only the board coordinates.
# NOTE(review): rows line up with batch samples only if every sample has exactly one -1 cell.
food_coord = food_hits[:, 1:]
food_coord

tensor([[0, 2, 6],
        [1, 2, 6]])
torch.Size([2, 3])


tensor([[2, 6],
        [2, 6]])

In [4]:
# Dump the full frame stack of both samples for visual inspection.
for label, sample in zip(("First batch:\n", "Second batch:\n"), z):
    print(label, sample)
# open question: how do we recover the head movement from these frames?

First batch:
 tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        

In [5]:
# frame == 1 tells us where the snake is; channel 0 is the current frame, channel 1 the previous one.
# The earlier mask `(f0 | f1) & f0` simplifies to just `f0`, i.e. the WHOLE snake, which only
# coincides with the head while the snake is a single cell long.
# The head is the cell that is occupied now but was not occupied one frame earlier.
# NOTE(review): this finds no head when the two frames are identical, or when the head steps
# onto a cell that the body still occupied in the previous frame -- the assertion below catches both.
snake_now = z[:,0,:,:] == 1
snake_before = z[:,1,:,:] == 1
head_coord_bool = torch.logical_and(snake_now, torch.logical_not(snake_before))

# argwhere rows are matched to batch samples purely by position, so require exactly one hit per sample
assert (head_coord_bool.flatten(1).sum(dim=1) == 1).all(), "expected exactly one head cell per sample"

# finally grab the coordinates (drop the batch-index column)
head_coord = torch.argwhere(head_coord_bool)[:,1:] # shaped like (batch_size,2)
head_coord

tensor([[1, 3],
        [1, 2]])

In [6]:
# now we want to grab the direction, so basically (dx,dy) = current head - previous head
# The previous head is the cell occupied in frame 1 but not in frame 2.
# (The earlier mask `(f1 | f2) & f1` simplifies to `f1`, i.e. the whole snake of frame 1,
# which is only correct for a one-cell snake.)
prev_snake = z[:,1,:,:] == 1
older_snake = z[:,2,:,:] == 1
prev_head_coord_bool = torch.logical_and(prev_snake, torch.logical_not(older_snake))

# argwhere rows are matched to batch samples purely by position, so require exactly one hit per sample
assert (prev_head_coord_bool.flatten(1).sum(dim=1) == 1).all(), "expected exactly one previous-head cell per sample"

prev_head_coord = torch.argwhere(prev_head_coord_bool)[:,1:] # shaped like (batch_size,2)
direction = (head_coord - prev_head_coord).to(int)
direction

tensor([[ 0, -1],
        [ 0, -1]])

In [7]:
# Inspect the first sample of the batch (two most recent frames only).
# In the board's reference frame the food sits at head + [1, 3], i.e.
# 1 step along the positive first axis and 3 along the positive second axis,
# whereas the snake is heading along [0, -1] (negative second axis).
z[0, :2]

tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0

In [8]:
# now we define the vector that is the difference between the head and the food (board frame)
diff = food_coord - head_coord
print(diff)

# if we define the snake to have x axis as the direction he's going to and y the lateral axis,
# with positive y being the heading rotated 90 degrees ccw,
# then for the first sample the food is 3 steps backward and 1 step sideways, so [-3,1]

# 90-degree rotation matrices; kept because the free-cell computation below uses them
rotation_matrix_90ccw = torch.Tensor([[0,-1],[1,0]]).to(int)
rotation_matrix_90cw = torch.Tensor([[0,1],[-1,0]]).to(int) # for later use

# The rotation that maps a unit heading d = (dx, dy) onto [1, 0] is [[dx, dy], [-dy, dx]].
# Applying it to diff is a dot product (forward component) and a 2D cross product
# (lateral component), so the whole batch is handled at once -- no per-sample loop.
# The previous `while` search over successive rotations never terminated when a heading
# was not a unit axis vector (e.g. [0, 0]); fail fast on that case instead.
assert (direction.abs().sum(dim=1) == 1).all(), "direction must be a unit step along one axis"

forward = (direction * diff).sum(dim=1)
lateral = direction[:,0] * diff[:,1] - direction[:,1] * diff[:,0]
diff = torch.stack((forward, lateral), dim=1)

# as expected
print(diff)

tensor([[1, 3],
        [1, 4]])
tensor([[-3,  1],
        [-4,  1]])


In [9]:
# now to grab the number of free cells available in each candidate direction:
# count consecutive cells, starting next to the head, that are inside the board and
# hold 0 or -1 in the current frame; stop at the board edge or at any other value.
# a for loop along the batch size is used here...

# initialize a tensor of shape (batch_size,3)
free_cells = torch.zeros((z.shape[0],3))
n_rows, n_cols = z.shape[2], z.shape[3]

for b in range(z.shape[0]):

    board = z[b,0]  # current frame of sample b

    # column order of free_cells: heading rotated 90 ccw, heading, heading rotated 90 cw
    directions = [
        rotation_matrix_90ccw @ direction[b],
        direction[b],
        rotation_matrix_90cw @ direction[b]
    ]
    for idd,d in enumerate(directions):

        step_r, step_c = d.tolist()
        r, c = head_coord[b].tolist()
        r, c = r + step_r, c + step_c

        # Bounds are tested BEFORE the board is indexed: reading first raised IndexError
        # past the far edge and silently wrapped around (negative index) past the near edge.
        # The cell tested is always the current one; testing a stale index counted the
        # first blocked cell as free.
        k = 0
        while 0 <= r < n_rows and 0 <= c < n_cols and board[r, c].item() in (0, -1):
            k += 1
            r, c = r + step_r, c + step_c

        free_cells[b,idd] = k


free_cells

tensor([[8., 3., 1.],
        [8., 2., 1.]])

In [10]:
# Recap of the 3 + 2 hand-built features per sample:
#   - free_cells: number of free cells along each of the 3 candidate directions
#   - diff: food position relative to the snake, in the snake's own frame (2 values)

print(free_cells, diff, sep="\n")

tensor([[8., 3., 1.],
        [8., 2., 1.]])
tensor([[-3,  1],
        [-4,  1]])


In [11]:
# Join both feature groups column-wise into a single (batch_size, 5) tensor.
torch.cat([free_cells, diff], dim=1)

tensor([[ 8.,  3.,  1., -3.,  1.],
        [ 8.,  2.,  1., -4.,  1.]])

In [12]:
from src.models.NN_custom_features import DQN

In [13]:
# Instantiate the network and run the raw frame batch through it.
# NOTE(review): the argument 3 is presumably the number of actions -- confirm against DQN.
model = DQN(3)
model_out = model(z)
model_out

tensor([[-1.1590, -1.0395, -1.0871],
        [-1.1927, -0.7780, -1.2228]], grad_fn=<AddmmBackward0>)