In [1]:
import gym
import numpy as np
import argparse
from itertools import count
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from drawnow import drawnow
import matplotlib.pyplot as plt

In [2]:
last_score_plot = [0]
avg_score_plot = [0]

In [3]:
device = torch.device("cuda:02")

In [4]:
def draw_fig():
    plt.title('reward')
    plt.plot(last_score_plot, '-')
    plt.plot(avg_score_plot, 'r-')

In [5]:
class FLAGS():
    lr = 1e-3
    gamma = 0.9
    epsilon_start = 0.9
    epsilon_end = 0.05
    target_update = 10
    batch_size = 512
    max_episode = 512

In [6]:
cfg = FLAGS()

In [7]:
env = gym.make('CartPole-v0')

In [8]:
class Memory(object):
    def __init__(self, memory_size=10000):
        self.memory = deque(maxlen=memory_size)
        self.memory_size = memory_size
        
    def __len__(self):
        return len(self.memory)
    
    def append(self, item):
        self.memory.append(item)
        
    def sample_batch(self, batch_size):
        idx = np.random.permutation(len(self.memory))[:batch_size]
        return [self.memory[i] for i in idx]

In [9]:
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc3 = nn.Linear(64, 2)
    def forward(self, x):
        x = F.elu(self.fc1(x))
        x = self.fc3(x)
        return x

In [10]:
def get_action(state, epsilon):
    with torch.no_grad():
        greedy_action = torch.argmax(policy_net(state.to(device)), dim=1).item()
        random_action = np.random.randint(0, 2)
    return random_action if np.random.rand() < epsilon else greedy_action

In [11]:
def update_network(states, actions, next_states, rewards, dones):
    state_action_values = policy_net(states.to(device)).gather(1, actions[:, None].long().to(device)).squeeze().to(device)
    next_state_values = torch.max(target_net(next_states.to(device)), dim=1)[0].detach()
    expected_state_action_values = rewards.to(device) + next_state_values * (1 - dones) * cfg.gamma

    expected_state_action_values = expected_state_action_values.to(device)

    loss = F.mse_loss(state_action_values, expected_state_action_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [12]:
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())

<All keys matched successfully>

In [13]:
optimizer = optim.RMSprop(policy_net.parameters(), lr=cfg.lr, weight_decay=1e-4)
memory = Memory(10000)

In [14]:
for i in range(cfg.max_episode):
    episode_durations = 0
    state = env.reset()
    epsilon = (cfg.epsilon_end - cfg.epsilon_start) * (i / cfg.max_episode) + cfg.epsilon_start

    for t in count():
        action = get_action(torch.tensor(state).float()[None, :], epsilon)
        next_state, reward, done, _ = env.step(action)

        memory.append([state, action, next_state, reward, done])
        state = next_state

        if len(memory) > cfg.batch_size:
            states, actions, next_states, rewards, dones = \
            map(lambda x: torch.tensor(x).float(), zip(*memory.sample_batch(cfg.batch_size)))

            update_network(states, actions, next_states, rewards, dones)

        if done:
            episode_durations = t + 1
            avg_score_plot.append(avg_score_plot[-1] * 0.99 + episode_durations * 0.01)
            last_score_plot.append(episode_durations)
            #drawnow(draw_fig)
            break

    # Update the target network
    if i % cfg.target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())
        
        
    if i % 10 == 0:
        print(i)

0
10
20


RuntimeError: expected device cuda:2 but got device cpu

In [None]:
drawnow(draw_fig)