In [1]:
from collections import namedtuple, deque
from copy import deepcopy
from itertools import count
import math
import random

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from env import Env
from test_policy import evaluate_policy, print_policy

device = torch.device("cuda" if torch.cuda.is_available else "cpu")
print(device)

writer = SummaryWriter("runs/deep-sarsa")
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cuda


<torch._C.Generator at 0x7ff0501ca290>

In [2]:
BATCH_SIZE = 512
GAMMA = 1
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
ENV = Env(35)
NUM_ACTIONS = ENV.n_actions
NUM_STEPS = int(5e3)
print(ENV.R)

35


In [3]:
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state','next_action'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
    
    def push(self, *args):
        self.memory.append(Transition(*args))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [4]:
class Q_Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(1, 32)
        self.fc_2 = nn.Linear(32, 32)
        self.fc_3 = nn.Linear(32, NUM_ACTIONS)

    def forward(self, x):
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        x = self.fc_3(x)
        
        return x

In [5]:
policy_net = Q_Network().to(device)
target_net = Q_Network().to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters())
memory = ReplayMemory(NUM_STEPS)

steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[np.random.randint(0, 2)]], device=device, dtype=torch.long)

In [6]:
def optimizer_model(timestep=0, batch_num=0, avg_reward=0):
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    next_state_batch = torch.cat(batch.next_state)
    next_action_batch = torch.cat(batch.next_action)

    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
            next_state_values = target_net(next_state_batch).gather(1, next_action_batch) 
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    writer.add_scalar("Loss", loss, timestep)
    writer.add_scalar("Avg reward (over batch)", avg_reward, timestep)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss

In [7]:
cur_state = ENV.reset()
cur_state = torch.tensor ([cur_state], dtype=torch.float32, device=device).unsqueeze(0)
cur_action = select_action(cur_state)
next_state, reward = ENV.step(cur_action.item())
next_state = torch.tensor([next_state], dtype=torch.float32, device=device).unsqueeze(0)
reward = torch.tensor([reward], device=device)
next_action = select_action(next_state)
memory.push(cur_state, cur_action, reward, next_state, next_action)

for i in tqdm(range(1, NUM_STEPS), desc="Steps", leave=False):
    cur_state = deepcopy(next_state)
    cur_action = deepcopy(next_action)
    next_state, reward = ENV.step(cur_action.item())
    next_state = torch.tensor([next_state], dtype=torch.float32, device=device).unsqueeze(0)
    reward = torch.tensor([reward], device=device)
    next_action = select_action(next_state)
    memory.push(cur_state, cur_action, reward, next_state, next_action)

    loss = optimizer_model(timestep = i, avg_reward = reward.item())

    target_net_state_dict = target_net.state_dict()
    policy_net_state_dict = policy_net.state_dict()
    for key in policy_net_state_dict:
        target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
    target_net.load_state_dict(target_net_state_dict)

    if i%100 == 0:
        SAVE_PATH = './checkpoints/DeepSARSA/DeepSARSA_{}.pt'.format(i)
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), SAVE_PATH)

Steps:   0%|          | 0/4999 [00:00<?, ?it/s]

  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)


In [8]:
best_reward = -torch.inf
best_policy = Q_Network()
for i in tqdm(range(100, NUM_STEPS, 100), desc="Evaluating", leave=False):
    LOAD_PATH = f'./checkpoints/DeepSARSA/DeepSARSA_{i}.pt'
    policy_net = Q_Network().to(device)
    checkpoint = torch.load(LOAD_PATH)
    policy_net.load_state_dict(checkpoint)
    reward = evaluate_policy(ENV, policy_net)
    writer.add_scalar("Reward (over policy)", reward, i)
    if reward > best_reward:
        best_reward = reward
        best_policy = deepcopy(policy_net)

print_policy(best_policy)

Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 
1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
