## DQN Tutorial

In [6]:
%matplotlib inline

In [7]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [29]:
from collections import namedtuple
from itertools import count

In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [31]:
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Bit flipping environment

In [68]:
class BitFlipEnv():
    
    def __init__(self, n = 8):
        self.n = n
        self.init_state = torch.randint(2, size=(n,))
        self.target_state = torch.randint(2, size=(n,))
        while np.array_equal(self.init_state, self.target_state):
            self.target_state = torch.randint(2, size=(n,))
        self.curr_state = self.init_state.clone()
        
    def step(self, action):
        self.curr_state[action] = 1 - self.curr_state[action]
        if np.array_equal(self.curr_state, self.target_state):
            return self.curr_state.clone(), 0
        else:
            return self.curr_state.clone(), -1
        
    def reset(self):
        self.init_state = torch.randint(2, size=(self.n,))
        self.target_state = torch.randint(2, size=(self.n,))
        while np.array_equal(self.init_state, self.target_state):
            self.target_state = torch.randint(2, size=(self.n,))
        self.curr_state = self.init_state.clone()

In [69]:
env = BitFlipEnv(n=10)

In [70]:
print('Initial state:', env.init_state)
print('Target state:', env.target_state)
curr_state, reward = env.step(9)
print('State, reward after taking action 9:', curr_state, reward)
print(type(curr_state))
env.reset()

Initial state: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  1.])
Target state: tensor([ 1.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.])
State, reward after taking action 9: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]) -1
<class 'torch.Tensor'>


## Replay Memory

In [71]:
Transition = namedtuple('Transition', 
                       ('state', 'action', 'next_state', 'reward', 'goal'))

class ReplayMemory(object):
    
    def __init__(self, capacity = 1e5):
        self.capacity = capacity
        self.memory = []
    
    def push(self, *args):
        """Saves a transition which should contain:
        - current state
        - action taken
        - next state
        - reward obtained
        - goal state"""
        self.memory.append(Transition(*args))
        if len(self.memory) > self.capacity:
#             print('!!!!!exceeded!')
            del self.memory[0]

    def sample(self, batch_size):
        """
        Returns batch_size number of samples from the replay memory
        """
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [72]:
replay = ReplayMemory(2)
print(replay.capacity)
replay.push(curr_state, 9, curr_state, 0, curr_state)
print(len(replay))
print(replay.memory)
curr_state, reward = env.step(9)
replay.push(curr_state, 9, curr_state, 1, curr_state)
print(len(replay))
print(replay.memory)
curr_state, reward = env.step(1)
replay.push(curr_state, 1, curr_state, 2, curr_state)
print(len(replay))
print(replay.memory)

replay.sample(1)

2
1
[Transition(state=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]), action=9, next_state=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]), reward=0, goal=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]))]
2
[Transition(state=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]), action=9, next_state=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.]), reward=0, goal=tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.])), Transition(state=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), action=9, next_state=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), reward=1, goal=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]))]
2
[Transition(state=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), action=9, next_state=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), reward=1, goal=tensor([ 0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.])), Transition(state=tensor([ 0.,  0.,  0.,  1.,  0.,  0.

[Transition(state=tensor([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), action=1, next_state=tensor([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]), reward=2, goal=tensor([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.]))]

## Q-Network in pytorch

In [73]:
class DQN(nn.Module):
    
    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        self.head = nn.Linear(448, 2)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))

## Feedforward network

In [74]:
NUM_BITS = 10
HIDDEN_SIZE = 256
OUTPUT = 1 

In [75]:
class FNN(nn.Module):
    
    def __init__(self):
        super(FNN, self).__init__()
        self.ln1 = nn.Linear(NUM_BITS*2, HIDDEN_SIZE)
        self.ln2 = nn.Linear(HIDDEN_SIZE, NUM_BITS)
        
    def forward(self, x):
        x = F.relu(self.ln1(x))
        x = self.ln2(x)
        return x
    

## Input extraction

## Hyper parameters

In [76]:
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
MODEL_PATH = 'her_policy_net.pt'
WEIGHTS_PATH = 'her_policy_net_weights.pt'

## Initialize Model

In [77]:
policy_net = FNN().to(device)
target_net = FNN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

steps_done = 0

## Helper functions

In [107]:
def select_action(state, goal):
    global steps_done
    sample = random.random()
    state_goal = torch.cat((state, goal))
#     print(state)
#     print(goal)
#     print(state_goal)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state_goal).argmax().view(1,1)
    else: 
        return torch.tensor([[random.randrange(NUM_BITS)]], device=device, dtype=torch.long)

episode_durations = []

def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
        
    plt.pause(0.001)
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())

In [108]:
# select_action(env.init_state, env.target_state)

In [109]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, 
                                           batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state 
                                      if s is not None])
    
    
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    goal_batch = torch.cat(batch.goal)
    state_goal_batch = torch.cat((state_batch, goal_batch))
    non_final_next_states_goal = torch.cat((non_final_next_states, goal_batch))
    
    state_action_values = policy_net(state_goal_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states_goal).max(1)[0].detach()
    
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [110]:
num_episodes = 1000
env = BitFlipEnv(NUM_BITS)
for i_episode in range(num_episodes):
    env.reset()
    state = env.init_state
    goal = env.target_state
    for t in count():
        action = select_action(state, goal)
        next_state, reward, = env.step(action.item())
        reward = torch.tensor([reward], device=device)
            
        memory.push(state, action, next_state, reward, goal)
        
        state = next_state
        
        optimize_model()
#         if done:
#             episode_durations.append(t + 1)
#             plot_durations()
#             break
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

## save weights
torch.save(policy_net.state_dict(), WEIGHTS_PATH)
print('Weights saved')
torch.save(policy_net, MODEL_PATH)
print('Model saved')

print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
                

RuntimeError: size mismatch, m1: [1 x 2560], m2: [20 x 256] at /opt/conda/conda-bld/pytorch-cpu_1524582300956/work/aten/src/TH/generic/THTensorMath.c:2033

In [106]:
state_goal = torch.cat((env.init_state, env.target_state))
with torch.no_grad():
    print(policy_net(state_goal).argmax())

tensor(7)
