In [1]:
import itertools
import os
import random
from collections import deque

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from game import SnakeGame, ACTIONS


pygame 2.2.0 (SDL 2.0.22, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Hyperparameters and paths for the DQN run.
SEED = 42                    # seeds Python `random`, PyTorch and NumPy below for repeatable runs

GAMMA = 0.9                  # discount applied to the bootstrapped next-state value in the TD target
BATCH_SIZE = 1000            # transitions sampled from the replay buffer per gradient step
BUFFER_SIZE = 100_000        # replay buffer capacity (deque maxlen: oldest transitions are evicted)
MIN_REPLAY_SIZE = 1_000      # random-policy transitions collected before training starts
EPSILON_START = 1.0          # exploration rate at step 0
EPSILON_END = 0.01           # exploration rate after EPSILON_DECAY steps
EPSILON_DECAY = 5_000        # steps over which epsilon is linearly interpolated START -> END
TARGET_UPDATE_FREQ = 1_000   # steps between copies of online_net weights into target_net
LEARNING_RATE = 0.001        # Adam learning rate
TARGET_SAVE_FREQ = TARGET_UPDATE_FREQ * 25  # steps between target-net checkpoints
MODELS_DIR = '../saved_models'
LOG_DIR = './logs/snake_1_0'

# NOTE(review): SnakeGame may use its own RNG; only the three global RNGs below
# are seeded here — confirm whether the env needs a separate seed.
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)  # last call returns None, so the cell displays no output


In [None]:
# TensorBoard writer used by the training loop's logging; run this cell before training.
summary_writer = SummaryWriter(log_dir=LOG_DIR)

In [3]:
class Network(nn.Module):
    """Q-network: an MLP mapping an observation vector to one Q-value per action.

    Input width is ``env.get_observation_space_size()`` and output width is
    ``env.get_action_space_size()``; there are two hidden ReLU layers of
    256 and 128 units.
    """

    def __init__(self, env):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(env.get_observation_space_size(), 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, env.get_action_space_size()),
        )

    def forward(self, x):
        """Return Q-values for ``x``; its last dimension must equal the observation size."""
        return self.net(x)

    def act(self, obs):
        """Return the greedy action index (a Python ``int``) for a single observation.

        ``obs`` is converted to a float32 tensor and given a leading batch axis.
        The forward pass runs under ``torch.no_grad()``: choosing an action never
        needs gradients, so building an autograd graph on every environment step
        would only waste time and memory.
        """
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            q_values = self.forward(obs_t.unsqueeze(0))
        return int(torch.argmax(q_values, dim=1)[0].item())

In [4]:
# Environment plus the rolling stores used during training.
env = SnakeGame()

# Experience replay memory; once BUFFER_SIZE is reached the oldest transitions drop off.
replay_buffer = deque(maxlen=BUFFER_SIZE)

# Sliding windows over the 100 most recent finished episodes (used for logging).
reward_buffer, score_buffer = deque(maxlen=100), deque(maxlen=100)

# Running reward total of the episode currently in progress.
episode_reward = 0.0

In [5]:
# online_net is trained directly; target_net starts as an exact copy and is
# re-synced periodically by the training loop to provide bootstrap targets.
online_net, target_net = Network(env), Network(env)
target_net.load_state_dict(online_net.state_dict())

# Only the online network's parameters are optimized.
optimizer = torch.optim.Adam(params=online_net.parameters(), lr=LEARNING_RATE)

In [6]:
# Initialize the replay buffer: take MIN_REPLAY_SIZE uniformly random actions
# so the first gradient steps have transitions to sample from.
obs = env.reset()

for _prefill_step in range(MIN_REPLAY_SIZE):
    random_action = env.sample_from_action_space()
    next_obs, reward, is_done, _info = env.step(random_action)

    # Transition layout: (obs, action, reward, done, next_obs).
    replay_buffer.append((obs, random_action, reward, is_done, next_obs))

    # Start a fresh episode when this one ended; otherwise continue from next_obs.
    obs = env.reset() if is_done else next_obs

In [7]:
# training loop
#
# Epsilon-greedy DQN: one environment step and one gradient step per iteration,
# with periodic target-network syncs, checkpoints and logging. Runs until the
# kernel is interrupted. The cells above (including the SummaryWriter cell)
# must have been run first.

LOG_FREQ = 1_000  # steps between console / TensorBoard reports


def _transitions_to_tensors(transitions):
    """Collate ``(obs, action, rew, done, new_obs)`` tuples into batched tensors.

    Returns ``(obses, actions, rewards, dones, new_obses)``. Observations are
    float32 stacked along a new leading batch axis; actions (int64), rewards
    and dones (float32) get a trailing singleton axis so they line up with
    ``gather(dim=1)`` and ``max(dim=1, keepdim=True)``.
    """
    obses, actions, rewards, dones, new_obses = zip(*transitions)
    return (
        torch.as_tensor(np.asarray(obses), dtype=torch.float32),
        torch.as_tensor(np.asarray(actions), dtype=torch.int64).unsqueeze(-1),
        torch.as_tensor(np.asarray(rewards), dtype=torch.float32).unsqueeze(-1),
        torch.as_tensor(np.asarray(dones), dtype=torch.float32).unsqueeze(-1),
        torch.as_tensor(np.asarray(new_obses), dtype=torch.float32),
    )


def _dqn_loss(online, target, transitions, gamma):
    """Smooth-L1 loss between ``online``'s Q(s, a) and one-step targets from ``target``."""
    obses_t, actions_t, rewards_t, dones_t, new_obses_t = _transitions_to_tensors(transitions)

    # Targets are constants for this update. Computing them under no_grad keeps
    # the target network out of the autograd graph, so backward() neither
    # traverses it nor accumulates .grad on its parameters.
    with torch.no_grad():
        max_target_q_values = target(new_obses_t).max(dim=1, keepdim=True)[0]
        targets = rewards_t + gamma * (1 - dones_t) * max_target_q_values

    q_values = online(obses_t)
    action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)
    return nn.functional.smooth_l1_loss(action_q_values, targets)


def _mean_or_nan(values):
    """Mean of ``values``; NaN (without NumPy's empty-slice warning) when empty."""
    return float(np.mean(values)) if len(values) > 0 else float('nan')


os.makedirs(MODELS_DIR, exist_ok=True)  # torch.save does not create missing directories
checkpoint_path = os.path.join(MODELS_DIR, 'snake_target_net_.pth')

obs = env.reset()
# env.reset() starts a new episode, so discard any partial total left over from
# an earlier (e.g. interrupted) run of this cell.
episode_reward = 0.0

for step in itertools.count():
    # Linear decay from EPSILON_START to EPSILON_END over EPSILON_DECAY steps, then flat.
    epsilon = np.interp(step, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])

    if random.random() < epsilon:
        action = env.sample_from_action_space()
    else:
        action = online_net.act(obs)

    new_obs, rew, done, infos = env.step(action)
    replay_buffer.append((obs, action, rew, done, new_obs))
    obs = new_obs
    episode_reward += rew

    if done:
        obs = env.reset()
        score_buffer.append(infos['score'])
        reward_buffer.append(episode_reward)
        episode_reward = 0.0

    # Gradient step. Skipped only while the buffer holds fewer than BATCH_SIZE
    # transitions (possible if MIN_REPLAY_SIZE < BATCH_SIZE); random.sample
    # would raise otherwise.
    if len(replay_buffer) >= BATCH_SIZE:
        transitions = random.sample(replay_buffer, BATCH_SIZE)
        loss = _dqn_loss(online_net, target_net, transitions, GAMMA)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Sync the target network with the online network.
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())

    # Checkpointing (also fires at step 0).
    if step % TARGET_SAVE_FREQ == 0:
        print("Saving target net")
        torch.save(target_net.state_dict(), checkpoint_path)

    # Logging
    if step % LOG_FREQ == 0:
        rew_mean = _mean_or_nan(reward_buffer)
        score_mean = _mean_or_nan(score_buffer)
        print()
        print('Step', step)
        print('Avg Rew', rew_mean)
        print('Avg Score', score_mean)
        # Both windows are appended together at episode end; wait until one
        # episode has finished instead of writing NaN to TensorBoard.
        if reward_buffer:
            summary_writer.add_scalar('avg_rew', rew_mean, global_step=step)
            summary_writer.add_scalar('avg_score', score_mean, global_step=step)


    

Saving target net

Step 0
Avg Rew 0.0

Step 1000
Avg Rew -8.695652173913043

Step 2000
Avg Rew -9.387755102040817

Step 3000
Avg Rew -9.375

Step 4000
Avg Rew -9.2

Step 5000
Avg Rew -9.3

Step 6000
Avg Rew -9.2

Step 7000
Avg Rew -9.3

Step 8000
Avg Rew -9.2

Step 9000
Avg Rew -9.1

Step 10000
Avg Rew -9.1

Step 11000
Avg Rew -9.1

Step 12000
Avg Rew -9.1

Step 13000
Avg Rew -9.1

Step 14000
Avg Rew -9.1

Step 15000
Avg Rew -9.1

Step 16000
Avg Rew -9.1

Step 17000
Avg Rew -9.1

Step 18000
Avg Rew -9.1

Step 19000
Avg Rew -9.1

Step 20000
Avg Rew -9.1

Step 21000
Avg Rew -9.1

Step 22000
Avg Rew -9.0

Step 23000
Avg Rew -9.0

Step 24000
Avg Rew -8.7
Saving target net

Step 25000
Avg Rew -8.4

Step 26000
Avg Rew -8.4

Step 27000
Avg Rew -8.3

Step 28000
Avg Rew -8.3

Step 29000
Avg Rew -8.5

Step 30000
Avg Rew -8.4

Step 31000
Avg Rew -8.4

Step 32000
Avg Rew -8.3

Step 33000
Avg Rew -8.3

Step 34000
Avg Rew -8.2

Step 35000
Avg Rew -8.2

Step 36000
Avg Rew -8.1

Step 37000
Avg Rew -8.

KeyboardInterrupt: 