In [1]:
from torch import nn
import torch
from collections import deque
import itertools
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter
from car import CarEnv,WIN

pygame 2.2.0 (SDL 2.0.22, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
GAMMA = 0.99               
BATCH_SIZE = 64          
BUFFER_SIZE = 100_000      
MIN_REPLAY_SIZE = 1_000    
EPSILON_START = 1.0 
EPSILON_END = 0.1
EPSILON_DECAY = 8_000      
TARGET_UPDATE_FREQ = 1_000  
LEARNING_RATE = 0.001
TARGET_SAVE_FREQ = TARGET_UPDATE_FREQ*25
MODELS_DIR = '../saved_models'
LOG_DIR = '../logs/car_1_1'

In [3]:
summary_writer = SummaryWriter(LOG_DIR)

In [4]:
class Network(nn.Module):
    def __init__(self,env):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(env.get_observation_space_size(), 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, env.get_action_space_size()),
        )

    def forward(self, x):
        return self.net(x)
    def act(self,obs):
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        q_values = self.forward(obs_t.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        action = max_q_index.detach().item()
        return action

In [5]:
env = CarEnv()

replay_buffer = deque(maxlen=BUFFER_SIZE)

reward_buffer = deque([],maxlen=100)
score_buffer = deque([],maxlen=100)
episode_reward = 0.0

In [6]:
online_net = Network(env)
target_net = Network(env)
target_net.load_state_dict(online_net.state_dict())

optimizer = torch.optim.Adam(online_net.parameters(), lr=LEARNING_RATE)

In [7]:
# iniziatlize Replay Buffer
obs = env.reset()

for _ in range(MIN_REPLAY_SIZE):
    action = env.sample_from_action_space()

    new_obs, rew, done = env.step(action)    
    
    transition = (obs,action,rew,done,new_obs)
    
    replay_buffer.append(transition)
    obs = new_obs

    if done:
        obs = env.reset()

In [8]:
# training loop

obs = env.reset()

for step in itertools.count():
    epsilon = np.interp(step,[0,EPSILON_DECAY],[EPSILON_START,EPSILON_END])
    rnd_sample = random.random()

    if rnd_sample < epsilon:
        action = env.sample_from_action_space()
    else:
        action = online_net.act(obs)

    new_obs, rew, done = env.step(action)

    transition = (obs,action,rew,done,new_obs)
    replay_buffer.append(transition)
    obs = new_obs

    episode_reward += rew

    if done:
        obs = env.reset()
        reward_buffer.append(episode_reward)
        episode_reward = 0.0


    # start gradient step 
    transitions = random.sample(replay_buffer,BATCH_SIZE)
    
    obses = np.asarray([t[0] for t in transitions])
    actions = np.asarray([t[1] for t in transitions])
    rewards = np.asarray([t[2] for t in transitions])
    dones = np.asarray([t[3] for t in transitions])
    new_obses = np.asarray([t[4] for t in transitions])


    obses_t = torch.as_tensor(obses, dtype=torch.float32)
    
    actions_t = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(-1)
    rewards_t = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(-1)
    dones_t = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(-1)
    new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32)

    # compute targets

    target_q_values = target_net(new_obses_t)

    max_target_q_values = target_q_values.max(dim=1,keepdim=True)[0]

    targets = rewards_t + GAMMA * (1-dones_t) * max_target_q_values

    # loss

    q_values = online_net(obses_t)

    action_q_values = torch.gather(input = q_values, dim=1, index = actions_t)

    loss = nn.functional.smooth_l1_loss(action_q_values,targets)

    # gradient step

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # update target network if needed
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())
        
    # checkpointing
    if step % TARGET_SAVE_FREQ == 0:
        print("Saving target net")
        torch.save(target_net.state_dict(), MODELS_DIR+"/snake_target_net_.pth")
    
    # Logging
    if step % 1000 == 0:
        rew_mean = np.mean(reward_buffer)
        print()
        print('Step', step)
        print('Avg Rew',rew_mean)
        summary_writer.add_scalar('avg_rew', rew_mean, global_step=step)


    

Saving target net

Step 0
Avg Rew nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)



Step 1000
Avg Rew -630.0

Step 2000
Avg Rew -405.5

Step 3000
Avg Rew -516.3333333333334

Step 4000
Avg Rew -411.0

Step 5000
Avg Rew -230.42857142857142

Step 6000
Avg Rew -107.5

Step 7000
Avg Rew -39.526315789473685

Step 8000
Avg Rew -5.2592592592592595

Step 9000
Avg Rew 17.885714285714286

Step 10000
Avg Rew 31.666666666666668

Step 11000
Avg Rew 42.5625

Step 12000
Avg Rew 51.113207547169814

Step 13000
Avg Rew 62.25

Step 14000
Avg Rew 72.83333333333333

Step 15000
Avg Rew 79.6

Step 16000
Avg Rew 82.65217391304348

Step 17000
Avg Rew 83.69444444444444

Step 18000
Avg Rew 89.78666666666666

Step 19000
Avg Rew 92.96103896103897

Step 20000
Avg Rew 101.875

Step 21000
Avg Rew 101.875

Step 22000
Avg Rew 107.20987654320987

Step 23000
Avg Rew 111.97560975609755

Step 24000
Avg Rew 118.75581395348837
Saving target net

Step 25000
Avg Rew 126.94252873563218

Step 26000
Avg Rew 133.79775280898878

Step 27000
Avg Rew 139.43333333333334

Step 28000
Avg Rew 139.43333333333334

Step 290

KeyboardInterrupt: 