In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import namedtuple, deque
import random
from env import Env
import numpy as np
import wandb
import copy

# wandb setup
# number = 2
# NAME = "DQN_RL" + str(number)
# ID = "DQN_RL" + str(number)
# run = wandb.init(project='DQN_RL', name = NAME, id = ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
wandb.init(
    # set the wandb project where this run will be logged
    project="DQN_RL"

)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Network error (JSONDecodeError), entering retry loop.


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016752189583333323, max=1.0…

wandb: Network error (JSONDecodeError), entering retry loop.


Problem at: /var/folders/69/4nr0lmbs6h38d0yjvwqsk5b15ycj_r/T/ipykernel_72523/1326559565.py 1 <cell line: 1>


CommError: Run initialization has timed out after 60.0 sec. 
Please refer to the documentation for additional information: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'next_action'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [3]:
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 16)
        self.layer2 = nn.Linear(16, 16)
        self.layer3 = nn.Linear(16, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [4]:
torch.cuda.empty_cache()
env = Env(R=5)

In [5]:
BATCH_SIZE = 512
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
REPLAY_MEMORY_SIZE = 10000

n_actions = env.n_actions
state = env.reset()
n_observations = 1

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(REPLAY_MEMORY_SIZE)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        np.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[np.random.randint(0, 2)]], device=device, dtype=torch.long)

In [6]:
def optimize_model(timestep=0, batch_num=0, reward=0):
    if len(memory) < BATCH_SIZE:
        return 

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    next_state_batch = torch.cat(batch.next_state)
    next_action_batch = torch.cat(batch.next_action)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values = target_net(next_state_batch).gather(1, next_action_batch).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # wandb.log({'loss': loss, 'timestep': timestep, 'batch': batch_num})
    # wandb.log({'loss': loss, 'reward': reward, 'timestep': timestep}) #, 'batch': t})

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss


In [7]:
num_episodes = 20000
num_time_per_episode = 500

# wandb.config.update({
#     'max_timesteps': num_episodes*num_time_per_episode,
#     'batch_size': BATCH_SIZE,
#     'optimizer': 'Adam',
#     'learning_rate': 'default',
#     'replay_memory': REPLAY_MEMORY_SIZE, # 10000
#     'n_actions': n_actions,
#     'n_observations': n_observations
# })


In [8]:

for i in range(num_episodes):
    prev_state = 0
    prev_action = 0
    prev_reward = 0
    for j in range(num_time_per_episode):
        cur_state = env.reset()
        cur_state = torch.tensor ([cur_state], dtype=torch.float32, device=device).unsqueeze(0)
        cur_action = select_action(cur_state)
        print(cur_action.item())
        next_state, reward = env.step(cur_action.item())

        reward = torch.tensor([reward], device=device)

        if j>0:
            memory.push(prev_state, prev_action, prev_reward, cur_state, cur_action)

        prev_state = copy.deepcopy(cur_state)
        prev_action = copy.deepcopy(cur_action)
        prev_reward = copy.deepcopy(reward)

        optimize_model(timestep=(i*num_time_per_episode) + j, avg_reward=reward.item())

        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if i%100 == 0:
            SAVE_PATH = './checkpoints/DeepSARSA_{}.pt'.format(i)
            target_net.load_state_dict(policy_net.state_dict())
            torch.save(policy_net.state_dict(), SAVE_PATH)

        torch.cuda.empty_cache()


0
0
0
0
0
0
0
1
1
0
1
1
1
1
0
1
0
1
1
1
0
0
0
0
1
1
0
0
0
0
1
1
0
1
1
0
0
0
0
0
1
1
0
1
0
1
0
0
0
0
0
0
0
0
0
1
0
1
1
1
0
1
0
0
1
1
1
0
1
0
1
1
0
1
1
0
0
1
1
0
0
1
1
0
0
1
0
0
1
1
0
0
1
0
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
1
1
1
1
0
0
0
1
0
0
0
0
1
0
0
0
0
1
0
1
1
0
1
1
0
1
1
1
0
0
1
1
0
0
0
1
0
0
0
1
0
1
1
1
0
0
0
1
1
0
0
1
1
0
0
0
0
1
1
1
1
1
0
0
1
0
1
1
0
0
0
1
0
1
0
0
0
1
0
0
0
1
0
0
1
1
1
0
0
1
0
1
0
0
1
0
0
1
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
0
1
0
0
0
1
1
0
0
0
0
1
1
1
0
1
0
0
1
0
0
1
1
0
0
0
1
0
0
0
0
0
1
0
0
0
0
0
0
1
0
0
0
0
1
1
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
1
1
1
0
1
1
0
0
1
0
1
0
0
1
0
1
1
0
1
0
1
0
1
0
1
0
0
1
0
0
0
0
1
0
1
0
1
0
0
1
1
0
0
0
0
0
0
0
0
0
1
0
1
0
1
0
0
0
0
0
0
0
1
1
0
0
0
1
0
0
0
0
0
1
0
1
0
1
0
0
0
1
1
0
1
0
0
0
1
0
1
0
0
1
0
0
0
0
1
1
0
0
0
0
1
0
0
0
0
1
0
1
1
0
1
1
0
1
0
0
0
1
0
0
0
1
1
0
0
0
0
1
0
0
1
0
0
0
0
1
1
0
0
1
0
0
0
1
0
1
0
1
0
1
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
0
0
0
0
0
0
0
0
0
1
1
0
0
0
0
0
0
0
1
0
0
0
0
1
0
1
0
1
0
1
0
0
0
0
0
0
0
1
0
0
0
0
0


KeyboardInterrupt: 