In [1]:
import gym
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T

from wrappers import wrap_dqn, LazyFrames

In [9]:
class Net(nn.Module):
    def __init__(self, n_actions):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)
        
        self.value = nn.Linear(3136, 512)
        self.value_head = nn.Linear(512, n_actions)
        
        self.advantage = nn.Linear(3136, 512)
        self.advantage_head = nn.Linear(512, n_actions)
        

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), -1)
        
        value = self.value_head(F.relu(self.value(x)))
        advantage = self.advantage_head(F.relu(self.advantage(x)))
        
        return value + (advantage - advantage.mean())
    
class Memory:
    def __init__(self, state, action, reward, next_state, terminal):
        self.state = state
        self.action = action
        self.reward = reward
        self.next_state = next_state
        self.terminal = terminal

In [None]:
env = wrap_dqn(gym.make('PongNoFrameskip-v4'))

In [4]:
GAMMA = 0.99

EPSILON = 1
EPSILON_LOW = 0.02
EPSILON_DECAY = 10**5

TARGET_UPDATE_INTERVAL = 1000
MEMORY_SIZE = 100000
MEMORY_INIT = 10000
BATCH_SIZE = 32

q_func = Net(6)
target_func = Net(6)
target_func.load_state_dict(q_func.state_dict())

q_func.cuda()
target_func.cuda()

optimizer = optim.Adam(q_func.parameters())
loss = nn.SmoothL1Loss()
memories = deque(maxlen=MEMORY_SIZE)

In [5]:
def var(x):
    if isinstance(x, np.float64):
        x = torch.from_numpy(np.array([x]))
    if isinstance(x, LazyFrames):
        x = np.array(x).reshape(1, 4, 84, 84)
        x = torch.from_numpy(x)
    
    return Variable(x).type(torch.FloatTensor).cuda()

def select_action(x):
    if random.random() < EPSILON:
        return env.action_space.sample()
    else:
        return np.argmax(q_func(var(x)).data.cpu().numpy(), 1)[0]

In [6]:
def optimize():
    optimizer.zero_grad()

    batch = np.random.choice(memories, min(BATCH_SIZE, len(memories)))

    for memory in batch:
        next_action = np.argmax(q_func(var(memory.next_state)).data.cpu().numpy(), 1)[0]
        target = var(memory.reward + (GAMMA * target_func(var(memory.next_state)).data[0, next_action] if not memory.terminal else 0))
        output = q_func(var(memory.state))[0, memory.action]

        loss(output, target).backward()

    optimizer.step()

In [7]:
#frame_count = 0
observation = env.reset()
total_reward = 0

while True:
    frame_count += 1
    #env.render()

    previous_obs = observation
    action = select_action(observation)

    observation, reward, done, info = env.step(action)  
    memories.append(Memory(previous_obs, action, reward, observation, done))

    total_reward += reward
    if done:
        print(frame_count, total_reward)
        total_reward = 0
        observation = env.reset()
        
    if frame_count < MEMORY_INIT:
        continue
            
    EPSILON = max(EPSILON_LOW, EPSILON - frame_count / EPSILON_DECAY)
    
    optimize()

    if frame_count % TARGET_UPDATE_INTERVAL == 0:
        target_func.load_state_dict(q_func.state_dict())

[33mWARN: <class '__main__.FrameStack'> doesn't implement 'reset' method, but it implements deprecated '_reset' method.[0m
896 -21.0
1795 -20.0
3032 -19.0
3972 -21.0
4839 -20.0
5802 -20.0
6710 -21.0
7531 -21.0
8316 -21.0
9229 -21.0
10141 -20.0
10903 -21.0
11690 -21.0
12445 -21.0
13200 -21.0
13961 -21.0
14888 -20.0
15707 -21.0
16541 -21.0
17481 -20.0
18364 -21.0
19360 -21.0
21136 -16.0
22569 -19.0
23899 -18.0
25310 -20.0
26818 -20.0
28780 -17.0
31325 -12.0
33173 -15.0
35547 -15.0
38129 -13.0
40183 -15.0
43240 -4.0
45350 -16.0
47553 -12.0
50107 -10.0
52558 -14.0
54881 -13.0
57597 -13.0
59370 -14.0
61253 -15.0
63618 -12.0
67182 -6.0
71106 -2.0
74385 -10.0
77807 -5.0
81360 -3.0
84456 -10.0
88194 -3.0
90922 -9.0
93775 -7.0
97091 -7.0
100716 -3.0
103952 4.0
107030 6.0
110621 4.0
113709 7.0
116885 -1.0
119665 -6.0
123155 -1.0
126662 2.0
129211 10.0
131336 14.0
134349 10.0
137819 -2.0
141029 -7.0
143624 12.0
146806 6.0
149403 8.0
152637 4.0
155076 14.0
157533 12.0
160541 -4.0
163475 12.0
166

KeyboardInterrupt: 

In [8]:
import time

while True:
    env.render()
    time.sleep(0.025)
    previous_obs = observation
    action = select_action(observation)

    observation, reward, done, info = env.step(action)  

    if done:
        observation = env.reset()

KeyboardInterrupt: 

In [10]:
torch.save(q_func.state_dict(), 'model_state')