In [None]:
from collections import namedtuple, deque
from copy import deepcopy
from itertools import count
import math
import random

from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from env import Env

# Pick the GPU when one is present. `is_available` must be *called*: the bare
# function object is always truthy, which would force "cuda" on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

writer = SummaryWriter("runs/deep-sarsa")

# Seed every RNG the notebook draws from: `random` (replay sampling and the
# epsilon coin flip), NumPy (exploratory actions) and torch (network init).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): the meaning of the 4.5 argument is defined by env.Env — confirm/document there.
ENV = Env(4.5)

In [None]:
# Training-loop hyperparameters.
NUM_STEPS = 20_000      # environment steps; also used as the replay-memory capacity
BATCH_SIZE = 128        # transitions sampled per optimisation step
GAMMA = 0.99            # discount applied to Q(s', a') in the SARSA target
TAU = 0.005             # interpolation weight for the soft target-network update
LR = 1e-4               # learning rate (NOTE(review): only effective if passed to the optimizer)

# Epsilon-greedy exploration: epsilon decays exponentially from EPS_START
# toward EPS_END with a time constant of EPS_DECAY action selections.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# The Q-network emits one value per environment action.
NUM_ACTIONS = ENV.n_actions

In [None]:
# One on-policy SARSA experience: (s, a, r, s', a').
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'reward', 'next_state', 'next_action'),
)


class ReplayMemory:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once `capacity` transitions are held, each new push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack the positional fields into a Transition and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` stored transitions drawn uniformly without replacement."""
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        return self.memory.__len__()

In [None]:
class Q_Network(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action.

    With no arguments it builds the original architecture: 1 input feature,
    two ReLU hidden layers of 32 units, and NUM_ACTIONS outputs.

    Args:
        n_inputs: number of input features per state (default 1).
        n_actions: number of output Q-values; None means use NUM_ACTIONS.
        hidden_size: width of both hidden layers (default 32).
    """

    def __init__(self, n_inputs=1, n_actions=None, hidden_size=32):
        super().__init__()
        if n_actions is None:
            n_actions = NUM_ACTIONS
        # Layer names fc_1..fc_3 (and their creation order) are kept so saved
        # checkpoints still load and seeded initialisation is unchanged.
        self.fc_1 = nn.Linear(n_inputs, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return raw (unbounded) Q-values, one column per action."""
        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        return self.fc_3(x)

In [None]:
# Online network (trained) and target network (supplies the bootstrapped
# SARSA targets); the target starts as an exact copy of the online weights.
policy_net = Q_Network().to(device)
target_net = Q_Network().to(device)
target_net.load_state_dict(policy_net.state_dict())

# Pass LR explicitly: without it Adam silently falls back to its own default
# learning rate (1e-3) and the LR hyperparameter has no effect.
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
memory = ReplayMemory(NUM_STEPS)

# Number of action selections made so far; drives the epsilon decay schedule.
steps_done = 0

def select_action(state):
    """Choose an action for `state` with an exponentially decaying epsilon-greedy policy.

    Epsilon starts at EPS_START and decays toward EPS_END with time constant
    EPS_DECAY, measured in calls to this function (tracked in the global
    `steps_done`, which is incremented on every call).

    Args:
        state: batched state tensor fed to `policy_net` (a (1, n_inputs) tensor in this notebook).

    Returns:
        A (1, 1) long tensor holding the chosen action index in [0, NUM_ACTIONS).
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)
    steps_done += 1

    if sample > eps_threshold:
        # Exploit: greedy action under the current online network.
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)

    # Explore: uniform over *all* actions (previously capped at actions 0 and 1).
    return torch.tensor([[np.random.randint(0, NUM_ACTIONS)]], device=device, dtype=torch.long)

In [None]:
def optimizer_model(timestep=0, batch_num=0, avg_reward=0):
    """Run one SARSA gradient step on a minibatch sampled from the replay memory.

    The regression target for each sampled (s, a, r, s', a') is
    r + GAMMA * Q_target(s', a'). Here a' is the action stored in the
    transition (on-policy SARSA), not a greedy argmax.

    Args:
        timestep: global step used as the x-axis of the TensorBoard scalars.
        batch_num: unused; kept for call compatibility.
        avg_reward: value logged under the "Avg reward" scalar.

    Returns:
        The Smooth-L1 loss tensor for this step, or None (no update performed)
        when the memory holds fewer than BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return None

    transitions = memory.sample(BATCH_SIZE)
    # List of Transitions -> one Transition whose fields are tuples over the batch.
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)            # (B, 1): select_action returns (1, 1) tensors
    reward_batch = torch.cat(batch.reward)            # (B,): rewards are stored as 1-element tensors
    next_state_batch = torch.cat(batch.next_state)
    next_action_batch = torch.cat(batch.next_action)  # (B, 1)

    state_action_values = policy_net(state_batch).gather(1, action_batch)  # (B, 1)

    with torch.no_grad():
        # squeeze to (B,) so the sum with the (B,) rewards stays elementwise;
        # leaving it (B, 1) broadcasts to a (B, B) matrix and corrupts the target.
        next_state_values = target_net(next_state_batch).gather(1, next_action_batch).squeeze(1)
    expected_state_action_values = reward_batch + GAMMA * next_state_values  # (B,)

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    writer.add_scalar("Loss", loss.item(), timestep)
    writer.add_scalar("Avg reward", avg_reward, timestep)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss

In [None]:
import os

CHECKPOINT_DIR = './checkpoints'
# torch.save does not create missing directories, so ensure it exists up front.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# States whose greedy action is printed each time a new best loss is seen.
# NOTE(review): presumably these are the environment's valid states — confirm against Env.
PROBE_STATES = range(1, 25)


def _as_state_tensor(raw_state):
    """Wrap an environment observation as a float32 (1, 1) tensor on `device`.

    Assumes the observation is a scalar (Q_Network takes 1 input feature by default).
    """
    return torch.tensor([raw_state], dtype=torch.float32, device=device).unsqueeze(0)


def _as_reward_tensor(raw_reward):
    """Wrap a scalar reward as a 1-element float32 tensor on `device`."""
    return torch.tensor([raw_reward], dtype=torch.float32, device=device)


# Bootstrap the first SARSA tuple (s, a, r, s', a') before the main loop.
cur_state = _as_state_tensor(ENV.reset())
cur_action = select_action(cur_state)
raw_next_state, raw_reward = ENV.step(cur_action.item())
next_state = _as_state_tensor(raw_next_state)
reward = _as_reward_tensor(raw_reward)
next_action = select_action(next_state)
memory.push(cur_state, cur_action, reward, next_state, next_action)

best_loss = torch.inf
reward_total = reward.item()  # running sum behind the "Avg reward" scalar

for i in range(1, NUM_STEPS):
    # On-policy: the action already chosen for s' is the one executed now.
    # Tensors are never modified in place, so rebinding is enough (no deepcopy).
    cur_state, cur_action = next_state, next_action
    raw_next_state, raw_reward = ENV.step(cur_action.item())
    next_state = _as_state_tensor(raw_next_state)
    reward = _as_reward_tensor(raw_reward)
    next_action = select_action(next_state)
    memory.push(cur_state, cur_action, reward, next_state, next_action)

    reward_total += reward.item()
    # i + 1 rewards observed so far: the bootstrap step plus i loop steps.
    loss = optimizer_model(timestep=i, avg_reward=reward_total / (i + 1))

    # Soft update: target <- TAU * policy + (1 - TAU) * target.
    target_net_state_dict = target_net.state_dict()
    policy_net_state_dict = policy_net.state_dict()
    for key in policy_net_state_dict:
        target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
    target_net.load_state_dict(target_net_state_dict)

    if i % 100 == 0:
        SAVE_PATH = os.path.join(CHECKPOINT_DIR, 'DeepSARSA_{}.pt'.format(i))
        # NOTE(review): this hard copy overrides the soft update above every
        # 100 steps; confirm that is intended.
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), SAVE_PATH)

    if loss is not None and loss.item() < best_loss:
        best_loss = loss.item()
        # Greedy action per probe state; no autograd graph is needed for inspection.
        with torch.no_grad():
            greedy_actions = [policy_net(_as_state_tensor(s)).max(1)[1].item() for s in PROBE_STATES]
        print(' '.join(str(a) for a in greedy_actions))
        # Persist the lowest-loss model under a fixed name.
        SAVE_PATH = os.path.join(CHECKPOINT_DIR, 'DeepSARSA_best_model.pt')
        torch.save(policy_net.state_dict(), SAVE_PATH)