In [1]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import models
from comet_ml import Experiment
%matplotlib inline

In [2]:
import os

# Read the Comet API key from the environment instead of hard-coding it in the
# notebook. NOTE(review): the key previously committed here is exposed and
# should be rotated. If COMET_API_KEY is unset, api_key=None is passed and
# comet_ml presumably falls back to its own config lookup — confirm for the
# installed comet_ml version.
experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"), project_name="cart")

jupyter comet_ml enable
COMET INFO: 
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/syrios/cart/d11f6aaf5bd1477b920f2670ba521ab3



In [3]:
# CartPole environment. `.unwrapped` returns the underlying env without the
# wrappers that gym.make adds (presumably to drop the per-episode step limit —
# confirm for the installed gym version). get_cart_location/get_screen read
# `env.state` and `env.x_threshold` from this object.
env = gym.make('CartPole-v0').unwrapped

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [4]:
# Preprocessing for each cropped frame: tensor -> PIL image, resize so the
# smaller edge is 40 pixels using bicubic interpolation, then back to a float
# tensor. Image.BICUBIC has the same value as the old Image.CUBIC alias, which
# was removed in Pillow 10.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=Image.BICUBIC),
                    T.ToTensor()])

# Width in pixels of the rendered CartPole frame (based on the code from gym).
screen_width = 600


def get_cart_location():
    """Return the horizontal pixel position of the middle of the cart.

    World x-coordinates in ``[-env.x_threshold, env.x_threshold]`` are mapped
    linearly onto ``[0, screen_width]``; the result is truncated to an int.
    """
    pixels_per_unit = screen_width / (2 * env.x_threshold)
    cart_x = env.state[0]
    centre_offset = screen_width / 2.0
    return int(cart_x * pixels_per_unit + centre_offset)


def get_screen(device='cuda'):
    """Render the environment and return a cropped, resized frame tensor.

    The frame is cropped vertically to rows 160:320 and horizontally to a
    ``view_width``-pixel window centred on the cart (clamped at the screen
    edges), divided by 255, passed through ``resize``, and given a leading
    batch dimension.

    Args:
        device: device the returned tensor is moved to. Defaults to ``'cuda'``,
            matching the previously hard-coded ``.cuda()``.

    Returns:
        A tensor in BCHW order (batch size 1) on ``device``.
    """
    # transpose the rendered image into torch order (CHW)
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # Strip off the top and bottom of the screen
    screen = screen[:, 160:320]
    view_width = 320
    half_view = view_width // 2
    cart_location = get_cart_location()
    if cart_location < half_view:
        # Cart near the left edge: take the leftmost view_width columns.
        slice_range = slice(view_width)
    elif cart_location > (screen_width - half_view):
        # Cart near the right edge: take the rightmost view_width columns.
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - half_view,
                            cart_location + half_view)
    # Strip off the edges, so that we have a square image centered on a cart
    screen = screen[:, :, slice_range]
    # Convert to float32 and rescale (presumably uint8 pixels -> [0, 1]).
    # The dtype conversion copies; torch.from_numpy then shares that buffer.
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)

## Replay Memory

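Store the transitions that the agent observes so that they can be reused later. Sampling from this buffer at random breaks up the correlation between consecutive transitions in a training batch, which improves stability.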

In [5]:
# One step of experience: the state, the action taken, the resulting state
# (None when the episode ended) and the reward received.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of Transitions for experience replay.

    Once ``capacity`` transitions are stored, each new push overwrites the
    oldest remaining entry.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next push writes to (wraps around).
        self.position = 0

    def push(self, *args):
        """Store ``Transition(*args)``, overwriting the oldest entry when full.

        The Transition is built before the buffer is modified, so invalid
        arguments raise TypeError without leaving a ``None`` placeholder that
        ``sample`` could later return.
        """
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number stored (from
                ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [6]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 128         # transitions per optimisation step
GAMMA = 0.99             # discount factor for future rewards
EPS_START = 0.9          # initial exploration rate (epsilon)
EPS_END = 0.05           # asymptotic exploration rate
EPS_DECAY = 100          # time constant (in steps) of the epsilon decay
TARGET_UPDATE = 10       # sync target net from policy net every N episodes
num_episodes = 10000
MEMORY_CAPACITY = 10000  # max transitions kept in the replay memory
# Derive the action count from the environment instead of hard-coding it.
N_ACTIONS = env.action_space.n

for param_name, param_value in [
        ("batch size", BATCH_SIZE),
        ("gamma", GAMMA),
        ("eps start", EPS_START),
        ("eps end", EPS_END),
        ("eps decay", EPS_DECAY),
        ("target update", TARGET_UPDATE),
        ("num episodes", num_episodes),
        ("memory capacity", MEMORY_CAPACITY),
        ("n actions", N_ACTIONS)]:
    experiment.log_parameter(param_name, param_value)


# The policy network chooses actions and is the one being trained.
policy_net = models.DQN(N_ACTIONS).cuda()
# The target network supplies Q-value estimates for the next state in the
# loss. It is a separate copy, synced only periodically, so the training
# targets do not shift with every update of the policy network.
target_net = models.DQN(N_ACTIONS).cuda()
target_net.load_state_dict(policy_net.state_dict())
# Put the target net in evaluation mode (affects layers such as dropout and
# batch norm); it is never trained directly.
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(MEMORY_CAPACITY)

# Global step counter driving the epsilon schedule in select_action.
steps_done = 0

def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    Epsilon decays exponentially from EPS_START towards EPS_END with time
    constant EPS_DECAY, measured in calls via the global ``steps_done``
    (incremented on every call).

    Args:
        state: network input for a single observation.

    Returns:
        A (1, 1) long tensor holding the chosen action index. With probability
        1 - epsilon it is the argmax of ``policy_net(state)``. Otherwise it is
        drawn uniformly from ``range(N_ACTIONS)`` and placed on ``state``'s
        device.
    """
    global steps_done
    # Draw first, then update the schedule (same RNG order as before).
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        # Exploit: pick the action with the highest predicted Q value.
        # No gradients needed because we are not learning here.
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    # Explore: uniform random action. Use N_ACTIONS (not a literal 2) so this
    # matches the network's output size, and follow the state's device
    # instead of hard-coding CUDA.
    return torch.tensor([[random.randrange(N_ACTIONS)]], device=state.device,
                        dtype=torch.long)

In [7]:
def optimize_model():
    """Run one DQN gradient step on a random minibatch from ``memory``.

    The prediction is Q(s, a) from ``policy_net`` for the actions actually
    taken. The target is r + GAMMA * max_a' Q_target(s', a') from
    ``target_net``, with the next-state value set to 0 for terminal
    transitions (``next_state is None``). The loss is Huber (smooth L1).

    Returns:
        The minibatch loss as a Python float, or None if ``memory`` holds fewer
        than BATCH_SIZE transitions (no update is made).
    """
    if len(memory) < BATCH_SIZE:
        return None

    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: list of Transitions -> Transition of tuples, e.g.
    # batch.state is a tuple holding every sampled state.
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Follow the device of the stored states instead of hard-coding CUDA.
    device = state_batch.device

    # Mask of transitions whose episode did not end (next state exists).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # Q value of the action actually taken in each state: gather picks, per
    # row, the column indexed by the stored action. These are the predictions.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Best target-net Q value for each next state; terminal transitions keep 0.
    next_state_values = torch.zeros(len(transitions), device=device)
    # Guard: torch.cat([]) would raise if every sampled transition is terminal.
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        # No gradient flows into the target network.
        with torch.no_grad():
            next_state_values[non_final_mask] = \
                target_net(non_final_next_states).max(1)[0]
    # Target: reward plus the discounted value of the best next action.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the policy net to become better at predicting Q values
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [8]:
# --- Training loop ---------------------------------------------------------
# Exponential moving average of episode reward (None until the first episode).
running_reward = None
reward_sum = 0
for i_episode in range(num_episodes):
    env.reset()
    # The state is the difference of two consecutive frames, presumably so the
    # network can perceive motion that a single frame does not show.
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    # Per-step losses for this episode (optimize_model may return None).
    losses = []
    for t in count():
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward_sum += reward
        # Create the reward tensor directly on the same device as the frames.
        reward = torch.tensor([reward], device=current_screen.device)

        last_screen = current_screen
        current_screen = get_screen()

        next_state = None if done else current_screen - last_screen
        memory.push(state, action, next_state, reward)
        state = next_state

        loss = optimize_model()
        if loss is not None:
            losses.append(loss)

        if done:
            running_reward = reward_sum if running_reward is None else running_reward * 0.99 + reward_sum * 0.01
            experiment.log_metric("reward sum", reward_sum, step=i_episode)
            reward_sum = 0
            experiment.log_metric("reward mean", running_reward, step=i_episode)
            if losses:
                experiment.log_metric("loss", sum(losses) / len(losses), step=i_episode)
            break
    # Periodically copy the policy weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        print("Episode: {}".format(i_episode))
        target_net.load_state_dict(policy_net.state_dict())

Episode: 0
Episode: 10
Episode: 20
Episode: 30
Episode: 40
Episode: 50
Episode: 60
Episode: 70
Episode: 80
Episode: 90
Episode: 100
Episode: 110
Episode: 120
Episode: 130
Episode: 140
Episode: 150
Episode: 160
Episode: 170
Episode: 180
Episode: 190
Episode: 200
Episode: 210
Episode: 220
Episode: 230
Episode: 240
Episode: 250
Episode: 260
Episode: 270
Episode: 280
Episode: 290
Episode: 300
Episode: 310
Episode: 320
Episode: 330
Episode: 340
Episode: 350
Episode: 360
Episode: 370
Episode: 380
Episode: 390
Episode: 400
Episode: 410
Episode: 420
Episode: 430
Episode: 440
Episode: 450
Episode: 460
Episode: 470
Episode: 480
Episode: 490
Episode: 500
Episode: 510
Episode: 520
Episode: 530
Episode: 540
Episode: 550
Episode: 560
Episode: 570
Episode: 580
Episode: 590
Episode: 600
Episode: 610
Episode: 620
Episode: 630
Episode: 640
Episode: 650
Episode: 660
Episode: 670
Episode: 680
Episode: 690
Episode: 700
Episode: 710
Episode: 720
Episode: 730
Episode: 740
Episode: 750
Episode: 760
Episode: 7

Episode: 5940
Episode: 5950
Episode: 5960
Episode: 5970
Episode: 5980
Episode: 5990
Episode: 6000
Episode: 6010
Episode: 6020
Episode: 6030
Episode: 6040
Episode: 6050
Episode: 6060
Episode: 6070
Episode: 6080
Episode: 6090
Episode: 6100
Episode: 6110
Episode: 6120
Episode: 6130
Episode: 6140
Episode: 6150
Episode: 6160
Episode: 6170
Episode: 6180
Episode: 6190
Episode: 6200
Episode: 6210
Episode: 6220
Episode: 6230
Episode: 6240
Episode: 6250
Episode: 6260
Episode: 6270
Episode: 6280
Episode: 6290
Episode: 6300
Episode: 6310
Episode: 6320
Episode: 6330
Episode: 6340
Episode: 6350
Episode: 6360
Episode: 6370
Episode: 6380
Episode: 6390
Episode: 6400
Episode: 6410
Episode: 6420
Episode: 6430
Episode: 6440
Episode: 6450
Episode: 6460
Episode: 6470
Episode: 6480
Episode: 6490
Episode: 6500
Episode: 6510
Episode: 6520
Episode: 6530
Episode: 6540
Episode: 6550
Episode: 6560
Episode: 6570
Episode: 6580
Episode: 6590
Episode: 6600
Episode: 6610
Episode: 6620
Episode: 6630
Episode: 6640
Episod

In [9]:
import os

# Persist only the learned policy weights (the state_dict). Create the target
# directory first so the save does not fail when it is missing.
POLICY_PATH = "./models/cart_dqn_policy.state"
os.makedirs(os.path.dirname(POLICY_PATH), exist_ok=True)
torch.save(policy_net.state_dict(), POLICY_PATH)