In [2]:
import itertools
import os
import random
from collections import deque

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from game import SnakeGame, ACTIONS


pygame 2.1.0 (SDL 2.0.16, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# Hyperparameters and paths for DQN training on SnakeGame.
GAMMA = 0.9                 # discount factor for bootstrapped targets
BATCH_SIZE = 1000           # transitions sampled per gradient step
BUFFER_SIZE = 100_000       # replay buffer capacity (deque maxlen: oldest transitions are dropped)
MIN_REPLAY_SIZE = 1_000     # random-policy transitions collected before training starts
EPSILON_START = 1.0         # initial exploration rate
EPSILON_END = 0.01          # final exploration rate
EPSILON_DECAY = 10_000      # steps over which epsilon is linearly annealed
TARGET_UPDATE_FREQ = 1_000  # steps between copying online weights into the target net
LEARNING_RATE = 0.001
TARGET_SAVE_FREQ = TARGET_UPDATE_FREQ*25  # steps between target-net checkpoints
MODELS_DIR = '../saved_models'
LOG_DIR = './logs/snake_3_1'
DOUBLE = True               # True: Double DQN targets; False: vanilla max-Q targets
SEED = 42

# Seed every RNG this notebook uses directly so runs are reproducible.
# NOTE(review): SnakeGame may keep its own RNG -- confirm it is covered by these seeds.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# random.sample in the training loop needs at least BATCH_SIZE stored transitions.
assert MIN_REPLAY_SIZE >= BATCH_SIZE, "replay warm-up must cover at least one full batch"


In [3]:
# TensorBoard writer; the training loop logs the 'avg_rew' / 'avg_score' scalars here.
summary_writer = SummaryWriter(log_dir=LOG_DIR)

In [28]:
class Network(nn.Module):
    """Convolutional Q-network: maps an image observation to one Q-value per action.

    Args:
        env: object exposing ``get_action_space_size()``; only used to size the
            output layer.
        input_shape: ``(channels, height, width)`` of a single observation. The
            flattened conv-feature size is inferred from it instead of being
            hard-coded. The default ``(3, 40, 40)`` yields 1600 features, the value
            the original layer used.
    """

    def __init__(self, env, input_shape=(3, 40, 40)):
        super().__init__()
        in_channels = input_shape[0]
        features = [
            nn.Conv2d(in_channels, 32, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.Flatten(),
        ]
        # Infer the flattened size with a dummy forward pass; zeros consume no RNG,
        # so weight initialisation is identical to the original layer order.
        with torch.no_grad():
            n_flat = nn.Sequential(*features)(torch.zeros(1, *input_shape)).shape[1]
        # One Sequential in the original order keeps state_dict keys (net.0 ... net.15)
        # unchanged, so previously saved checkpoints still load.
        self.net = nn.Sequential(
            *features,
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, env.get_action_space_size()),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a batch ``x``."""
        return self.net(x)

    def act(self, obs):
        """Return the greedy action index (Python int) for one unbatched observation.

        A batch dimension is added here, so ``obs`` should not already have one.
        """
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        # Pure inference: no_grad avoids building an autograd graph on every env step.
        with torch.no_grad():
            q_values = self.forward(obs_t.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        return max_q_index.item()

In [29]:
env = SnakeGame()

# Experience replay: bounded FIFO of (obs, action, rew, done, new_obs) tuples.
replay_buffer = deque(maxlen=BUFFER_SIZE)

# Rolling windows over the last 100 finished episodes (source of the logged averages).
reward_buffer, score_buffer = deque(maxlen=100), deque(maxlen=100)
episode_reward = 0.0  # reward accumulated so far in the current episode

In [30]:
# The online net is trained; the target net is a periodically synced copy used for targets.
online_net, target_net = Network(env), Network(env)
target_net.load_state_dict(online_net.state_dict())

# Only the online network's parameters are optimised.
optimizer = torch.optim.Adam(params=online_net.parameters(), lr=LEARNING_RATE)

In [23]:
# Initialize the replay buffer with random-policy transitions.
# Fills *until* the buffer holds enough transitions, so re-running this cell after
# an interruption tops the buffer up rather than appending another full batch,
# and an already-filled buffer is left untouched.
# Capped at BUFFER_SIZE because the deque can never grow beyond its maxlen.
warmup_target = min(MIN_REPLAY_SIZE, BUFFER_SIZE)

obs = env.reset()

while len(replay_buffer) < warmup_target:
    action = env.sample_from_action_space()

    # Bind the info dict to `info`, not `_`, so IPython's last-output variable isn't shadowed.
    new_obs, rew, done, info = env.step(action)

    transition = (obs, action, rew, done, new_obs)
    replay_buffer.append(transition)
    obs = new_obs

    if done:
        obs = env.reset()

KeyboardInterrupt: 

In [8]:
# training loop
# Runs until the kernel is interrupted: itertools.count() never ends and there is no break.
LOG_FREQ = 1_000  # steps between console / TensorBoard log lines

# torch.save (checkpointing below) does not create missing directories.
os.makedirs(MODELS_DIR, exist_ok=True)

obs = env.reset()

for step in itertools.count():
    # Epsilon-greedy: epsilon falls linearly from EPSILON_START to EPSILON_END over
    # EPSILON_DECAY steps; np.interp then holds it at EPSILON_END.
    epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])
    rnd_sample = random.random()

    if rnd_sample < epsilon:
        action = env.sample_from_action_space()
    else:
        action = online_net.act(obs)

    new_obs, rew, done, infos = env.step(action)

    transition = (obs, action, rew, done, new_obs)
    replay_buffer.append(transition)
    obs = new_obs

    episode_reward += rew

    if done:
        obs = env.reset()
        score_buffer.append(infos['score'])
        reward_buffer.append(episode_reward)
        episode_reward = 0.0

    # start gradient step -- only once the buffer can supply a full batch;
    # otherwise random.sample raises ValueError (e.g. after an interrupted warm-up cell).
    if len(replay_buffer) >= BATCH_SIZE:
        transitions = random.sample(replay_buffer, BATCH_SIZE)

        obses = np.asarray([t[0] for t in transitions])
        actions = np.asarray([t[1] for t in transitions])
        rewards = np.asarray([t[2] for t in transitions])
        dones = np.asarray([t[3] for t in transitions])
        new_obses = np.asarray([t[4] for t in transitions])

        obses_t = torch.as_tensor(obses, dtype=torch.float32)
        # Trailing singleton dim so these line up with the (batch, 1) gathered Q-values.
        actions_t = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(-1)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(-1)
        dones_t = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(-1)
        new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32)

        # compute targets: rew + GAMMA * (1 - done) * Q_target(new_obs, a'); (1 - done)
        # drops the bootstrap term on terminal transitions.
        with torch.no_grad():
            if DOUBLE:
                # Double DQN: the online net selects a', the target net evaluates it.
                target_online_q_values = online_net(new_obses_t)
                target_online_best_q_indeces = target_online_q_values.argmax(dim=1, keepdim=True)

                targets_target_q_values = target_net(new_obses_t)
                targets_selected_q_values = torch.gather(input=targets_target_q_values, dim=1, index=target_online_best_q_indeces)
                targets = rewards_t + GAMMA * (1 - dones_t) * targets_selected_q_values
            else:
                target_q_values = target_net(new_obses_t)
                max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
                targets = rewards_t + GAMMA * (1 - dones_t) * max_target_q_values

        # loss: smooth L1 between Q(obs, taken action) and the targets
        q_values = online_net(obses_t)
        action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)
        loss = nn.functional.smooth_l1_loss(action_q_values, targets)

        # gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # update target network if needed
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())

    # checkpointing (overwrites the same file each time)
    if step % TARGET_SAVE_FREQ == 0:
        print("Saving target net")
        torch.save(target_net.state_dict(), MODELS_DIR+"/snake_target_net_.pth")

    # Logging
    if step % LOG_FREQ == 0:
        # np.mean of an empty deque is nan (plus a RuntimeWarning); report None and
        # skip TensorBoard until at least one episode has finished.
        rew_mean = np.mean(reward_buffer) if len(reward_buffer) > 0 else None
        score_mean = np.mean(score_buffer) if len(score_buffer) > 0 else None
        print()
        print('Step', step)
        print('Avg Rew', rew_mean)
        print('Avg Score', score_mean)
        if rew_mean is not None:
            summary_writer.add_scalar('avg_rew', rew_mean, global_step=step)
        if score_mean is not None:
            summary_writer.add_scalar('avg_score', score_mean, global_step=step)


    

RuntimeError: The size of tensor a (1000) must match the size of tensor b (43) at non-singleton dimension 2

In [13]:
# Debug: shape of the Double-DQN target Q-values gathered in the latest gradient step.
targets_selected_q_values.size()

torch.Size([1000, 1, 43, 3])

In [17]:
# Debug: shape of the most recently sampled batch of next observations.
new_obses_t.size()

torch.Size([1000, 3, 40, 40])

In [16]:
# Debug: target-network output shape for the most recently sampled batch.
target_net(new_obses_t).size()

torch.Size([1000, 128, 43, 3])

In [31]:
# Sanity check: one random 3x40x40 input should produce a single row of Q-values.
target_net(torch.rand(1, 3, 40, 40)).size()

torch.Size([1, 3])

In [11]:
# Debug: shape of the done-flag column from the latest sampled batch.
dones_t.size()

torch.Size([1000, 1])

In [None]:
def save_checkpoint_step(target_net, optimizer, step, replay_buffer, reward_buffer, score_buffer, online_net=None):
    """Save a training snapshot to ``MODELS_DIR/snake_checkpoint_step_<step>.pth``.

    Args:
        target_net: target Q-network; its state_dict is stored under 'target_net'.
        optimizer: optimizer whose state_dict is stored under 'optimizer'.
        step: training step; stored under 'step' and used in the filename.
        replay_buffer, reward_buffer, score_buffer: stored as-is (pickled by torch.save).
        online_net: optional online Q-network, stored under 'online_net' when given.
            The optimizer was built over the online network's parameters, and the
            target net is only synced every TARGET_UPDATE_FREQ steps, so the online
            weights are needed to resume training exactly.

    Returns:
        The path the checkpoint was written to.
    """
    checkpoint = {
        'target_net': target_net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
        'replay_buffer': replay_buffer,
        'reward_buffer': reward_buffer,
        'score_buffer': score_buffer,
    }
    if online_net is not None:
        checkpoint['online_net'] = online_net.state_dict()

    # torch.save does not create missing directories.
    os.makedirs(MODELS_DIR, exist_ok=True)
    path = os.path.join(MODELS_DIR, "snake_checkpoint_step_{}.pth".format(step))
    # NOTE: the buffers are arbitrary pickled Python objects, so loading this file needs
    # torch.load(path, weights_only=False) on torch versions that default to weights_only=True.
    torch.save(checkpoint, path)
    return path


save_checkpoint_step(target_net, optimizer, step, replay_buffer, reward_buffer, score_buffer, online_net=online_net)