In [1]:
import gym
import math
import random
import numpy as np
# import matplotlib
# import matplotlib.pyplot as plt
from collections import namedtuple, deque
# from itertools import count
# from PIL import Image

from estimator import cost_estimate
from env import env

import torch
import torch.nn as nn
import torch.optim as optim
# import torch.nn.functional as F
# import torchvision.transforms as T


# env = gym.make('CartPole-v0').unwrapped

# # set up matplotlib
# is_ipython = 'inline' in matplotlib.get_backend()
# if is_ipython:
#     from IPython import display

# plt.ion()

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# One recorded environment step: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Fixed-capacity buffer of ``Transition`` records.

    Backed by a ``deque`` with ``maxlen=capacity``, so once the buffer is
    full every new transition evicts the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in a ``Transition`` and append it to the buffer."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [3]:
# Default width of the network's input (state vector) and output (one
# Q-value per action).
n_states=62
n_actions=62


class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    Maps a batch of state vectors of shape ``(batch, n_inputs)`` to a batch
    of per-action values of shape ``(batch, n_outputs)`` through hidden
    layers of width 128 and 64 with ReLU activations. The first linear layer
    has no bias term.

    Args:
        n_inputs: Number of input features. ``None`` (the default) uses the
            module-level ``n_states`` as it is at construction time.
        n_outputs: Number of output values. ``None`` (the default) uses the
            module-level ``n_actions`` as it is at construction time.
    """

    def __init__(self, n_inputs=None, n_outputs=None):
        super(DQN, self).__init__()
        # Resolve defaults at call time (not definition time) so ``DQN()``
        # keeps following the module globals exactly as before.
        if n_inputs is None:
            n_inputs = n_states
        if n_outputs is None:
            n_outputs = n_actions
        # Attribute name and layer order are kept so existing state_dict
        # keys ("linear_relu_stack.<i>.*") stay loadable.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(n_inputs, 128, bias=False),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_outputs),
        )

    def forward(self, x):
        """Return the raw (unnormalised) network outputs for ``x``."""
        logits = self.linear_relu_stack(x)
        return logits

In [4]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 32        # minibatch size drawn from replay memory per update
GAMMA = 0.99           # discount applied to the bootstrapped next-state value
EPS_START = 0.9        # initial exploration probability
EPS_END = 0.05         # floor of the exploration probability
EPS_DECAY = 0.997      # controls how fast exploration decays in select_action
TARGET_UPDATE = 10     # copy policy_net -> target_net every this many episodes
learning_rate=5e-4

# Seed every RNG this notebook draws from. Action selection and replay
# sampling use the stdlib ``random`` module and weight initialisation uses
# torch, so seeding NumPy alone does not make a run reproducible.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)


# The policy network is the one being trained; the target network starts as
# an exact copy and is only used for inference.
policy_net=DQN().to(device)
target_net=DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(),learning_rate)
memory = ReplayMemory(10000)

# Number of actions selected so far; drives the exploration schedule.
steps_done = 0


def select_action(state):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    The exploration probability starts at ``EPS_START`` and decays
    geometrically towards ``EPS_END``; ``EPS_DECAY`` is the per-step
    multiplicative factor and is expected to lie in (0, 1). Every call
    increments the global ``steps_done`` counter.

    Args:
        state: Current observation; anything ``torch.as_tensor`` can turn
            into a float vector accepted by ``policy_net``.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor on ``device`` holding the index
        of the chosen action.
    """
    global steps_done
    sample = random.random()
    # EPS_DECAY is a factor just below 1, so it must be applied as a power.
    # Feeding it to exp(-steps * EPS_DECAY) would drive epsilon to EPS_END
    # within a handful of steps and effectively disable exploration.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * (EPS_DECAY ** steps_done)
    steps_done += 1

    if sample > eps_threshold:
        # Exploit: take the action with the highest predicted value.
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device)
            # max(1) returns (values, indices) per row; keep the index.
            return policy_net(state_tensor.unsqueeze(0)).max(1)[1].view(1, 1)

    # Explore: sample uniformly over the whole action space, not just {0, 1}.
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)


episode_durations = []


In [6]:
def optimize_model():
    """Run one DQN gradient step on a minibatch drawn from ``memory``.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise it samples a batch, computes ``Q(s, a)`` with ``policy_net``,
    builds the target ``reward + GAMMA * max_a' Q_target(s', a')`` (using 0
    for transitions whose ``next_state`` is ``None``), minimises the Huber
    loss between the two, and clamps every gradient element to ``[-1, 1]``
    before the optimizer step.

    Stored states and rewards may be tensors or plain array-likes/scalars;
    they are converted with ``torch.as_tensor`` and moved to ``device``.

    Returns:
        None.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: a list of Transitions becomes one Transition
    # whose fields are tuples of length BATCH_SIZE.
    batch = Transition(*zip(*transitions))

    def as_state(raw_state):
        # Flatten so that states stored as (n,) or (1, n) both batch to (B, n).
        return torch.as_tensor(raw_state, dtype=torch.float32, device=device).reshape(-1)

    state_batch = torch.stack([as_state(s) for s in batch.state])
    # gather() needs a (B, 1) index. select_action returns (1, 1) tensors, so
    # they are concatenated along dim 0 (stacking would give (B, 1, 1)).
    action_batch = torch.cat([
        torch.as_tensor(a, dtype=torch.long, device=device).reshape(1, 1)
        for a in batch.action
    ])
    reward_batch = torch.stack([
        torch.as_tensor(r, dtype=torch.float32, device=device).reshape(())
        for r in batch.reward
    ])

    # True where the transition has a successor state (i.e. is not terminal).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # Q(s_t, a_t): evaluate all actions, then pick the column of the one taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the target network; stays 0 for terminal transitions.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():
        non_final_next_states = torch.stack([as_state(s) for s in batch.next_state
                                             if s is not None])
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # Bellman target.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss between prediction and target.
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model, clamping each gradient element to [-1, 1].
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()

In [7]:
# Smoke-test sized run. With BATCH_SIZE = 32 from the hyperparameter cell,
# 5 episodes x 5 steps only produce 25 transitions, so optimize_model()
# returns early on every call; raise num_episodes (e.g. 600) to train.
num_episodes = 5 #600
t_max=5
# NOTE: this overrides the TARGET_UPDATE value set in the hyperparameter cell.
TARGET_UPDATE = 2

# `env` is already imported in the first cell.
db_env=env()

for i_episode in range(num_episodes):
    # Start every episode from a freshly reset environment.
    db_env.reset()
    current_state = db_env.get_current_state()
    for t in range(t_max):
        # Select and perform an action
        # TODO reduce actions when generating or stepping
        action = select_action(current_state)
        next_state = db_env.step(action.item())

        # The reward signal is whatever cost_estimate reports for the
        # environment after the step.
        reward = cost_estimate(db_env)

        # Store the transition in memory
        memory.push(current_state, action, next_state, reward)

        # Move to the next state
        current_state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    print('episode {} completed'.format(i_episode))
print('Complete')
# env.render()
# env.close()
# plt.ioff()
# plt.show()

 0 completed
 1 completed
 2 completed
 3 completed
 4 completed
episode 0 completed
 0 completed
 1 completed
 2 completed
 3 completed
 4 completed
episode 1 completed
 0 completed
 1 completed
 2 completed
 3 completed
 4 completed
episode 2 completed
 0 completed
 1 completed
 2 completed
 3 completed
 4 completed
episode 3 completed
 0 completed
 1 completed
 2 completed
 3 completed
 4 completed
episode 4 completed
Complete
