In [1]:
import cv2
import os
import torch
import gym
import time
import math
import random

from collections import namedtuple 
from itertools import count

import numpy as np
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
class DataUtils():
    """Helpers for persisting recorded episodes to disk."""

    @staticmethod
    def save_episode(frames, episode, path):
        """
        Saves the episode as a video, returns the path of the saved video

        frames: A non-empty list of frames (arrays of shape H x W x C) to be saved as the video
        episode: Episode number, used as the name of the video file
        path: directory to save the video in

        Raises ValueError when frames is empty, since the video size is taken
        from the first frame.
        """

        if len(frames) == 0:
            raise ValueError('cannot save an episode without frames')

        height, width, _ = frames[0].shape

        path = os.path.join(path, str(episode) + '.mp4')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        # frames are written at their native size: the writer takes (width, height)
        video = cv2.VideoWriter(path, fourcc, 24, (width, height))

        # always release the writer, even if a frame fails to encode, so the
        # file handle is not leaked; no GUI windows are opened here, so there
        # is nothing for cv2.destroyAllWindows() to clean up
        try:
            for frame in frames:
                video.write(frame)
        finally:
            video.release()

        return path

In [3]:
# Create a Named Tuple to hold our experiences 

Experience = namedtuple(
    typename='Experience',
    field_names=[
        'state',        # state the action was taken from
        'action',       # action taken
        'reward',       # reward received for the action
        'next_state',   # state reached after the action
        'is_terminal',  # whether the reached state ended the episode
    ],
)

In [4]:
class ExperienceReplayBuffer(object):
    """
    Bounded buffer to hold experiences

    Once max_size experiences are stored, every push overwrites the oldest entry.
    """

    def __init__(self, size):
        """
        size: maximum number of experiences kept in the buffer, must be positive

        Raises ValueError for a non-positive size.
        """

        # a zero-sized buffer would otherwise only fail later, with a
        # ZeroDivisionError on the first push
        if size <= 0:
            raise ValueError('replay buffer size must be positive, got {}'.format(size))

        self.max_size = size # set the maximum size limit
        self.push_count = 0 # number of pushes seen so far
        self.buffer = [] # stored experiences

    def push(self, experience):
        """
        Adds an experience, replacing the oldest one when the buffer is full
        """

        # check if we have space in our buffer
        if len(self.buffer) < self.max_size:
            self.buffer.append(experience)
        else:
            # push_count wraps around the buffer, so this slot holds the oldest entry
            self.buffer[self.push_count % self.max_size] = experience

        # increment push count
        self.push_count += 1

    def is_samplable(self, batch_size):
        """
        Returns True when at least batch_size experiences are stored
        """
        return len(self.buffer) >= batch_size

    def sample(self, batch_size):
        """
        Returns an Experience whose fields are batched tensors built from
        batch_size experiences drawn without replacement from the buffer

        Raises ValueError (from random.sample) when fewer than batch_size
        experiences are stored.
        """

        # draw a random sample from the buffer and regroup it field by field
        batch = random.sample(self.buffer, batch_size)
        batch = Experience(*zip(*batch))

        states = torch.cat(batch.state)
        actions = torch.cat(batch.action)
        rewards = torch.cat(batch.reward)
        next_states = torch.cat(batch.next_state)
        is_terminal = self._stack_terminal_flags(batch.is_terminal)

        return Experience(states, actions, rewards, next_states, is_terminal)

    @staticmethod
    def _stack_terminal_flags(flags):
        """
        Builds the terminal mask from per-experience flags stored either as
        tensors or as plain Python booleans
        """

        # tensors are concatenated as before, keeping their dtype and device
        if all(torch.is_tensor(flag) for flag in flags):
            return torch.cat(flags)

        # torch.cat rejects plain booleans, so build the (CPU) mask directly
        return torch.tensor([bool(flag) for flag in flags], dtype=torch.bool)

In [5]:
class EpsilonGreedyStrategy(object):
    """Exploration-rate (epsilon) schedule that decays exponentially with the step count."""

    def __init__(self, start, end, decay):
        # epsilon at step zero, the value it decays towards, and the decay rate
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """
        Returns epsilon for the given step: end + (start - end) * exp(-current_step * decay)
        """
        gap = self.start - self.end
        falloff = math.exp(-1.0 * current_step * self.decay)
        return self.end + gap * falloff
        

In [6]:
class CartPoleEnvManager():
    """
    Wraps the CartPole gym environment and exposes rendered screens as states

    A state is the difference between two consecutive processed screens, so
    the network sees motion rather than a single still frame.
    """

    def __init__(self, device):
        """
        device: torch device that rewards and processed screens are placed on
        """
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        self.current_screen = None
        self.done = False

        # build the image pipeline once instead of once per rendered frame
        self._resize = T.Compose([
            T.ToPILImage()
            ,T.Resize((40,90))
            ,T.ToTensor()
        ])

    def reset(self):
        """
        Starts a new episode
        """
        self.env.reset()
        self.current_screen = None
        # clear the flag left over from the previous episode so a fresh
        # episode never reports itself as finished
        self.done = False

    def close(self):
        self.env.close()

    def render(self, mode='human'):
        return self.env.render(mode)

    def num_actions_available(self):
        return self.env.action_space.n

    def take_action(self, action):
        """
        Performs the action (a one-element tensor) in the environment, records
        whether the episode ended in self.done and returns the reward as a
        one-element tensor on self.device
        """
        # the environment is expected to return the 4-tuple (obs, reward, done, info)
        _, reward, self.done, _ = self.env.step(action.item())
        return torch.tensor([reward], device=self.device)

    def just_starting(self):
        """
        Returns True until the first screen of the episode has been captured
        """
        return self.current_screen is None

    def get_state(self):
        """
        Returns the current state

        At the start of an episode and once the episode is done, the state is
        an all-zero screen; otherwise it is the latest processed screen minus
        the previously captured one.
        """
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen()
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            s1 = self.current_screen
            s2 = self.get_processed_screen()
            self.current_screen = s2
            return s2 - s1

    def get_screen_height(self):
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        """
        Renders the environment and returns the cropped, resized screen as a
        (1, C, 40, 90) tensor on self.device
        """
        screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        """
        Keeps the rows between 40% and 80% of the height of a CHW screen
        """
        screen_height = screen.shape[1]

        # Strip off top and bottom
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        screen = screen[:, top:bottom, :]
        return screen

    def transform_screen_data(self, screen):
        """
        Converts a CHW screen array to a resized float tensor with a batch dimension
        """
        # Convert to float, rescale, convert to tensor
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        return self._resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)

In [7]:
class Agent():
    """
    Epsilon-greedy agent that picks actions for a policy network
    """

    def __init__(self, strategy, num_actions, device):
        """
        strategy: object exposing get_exploration_rate(step)
        num_actions: number of discrete actions to choose from
        device: torch device the returned action tensors are placed on
        """

        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net):
        """
        Returns a one-element action tensor, chosen at random with the current
        exploration rate and greedily from policy_net otherwise

        state: batch holding a single state
        policy_net: network mapping a batch of states to per-action Q values
        """

        # use the strategy this agent was built with, not a module-level global
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action], device=self.device) # explore

        with torch.no_grad():
            return policy_net(state).argmax(dim=1).to(self.device) # exploit

    def test_agent(self, policy_net, testEnv, max_steps=None):
        """
        Plays one greedy episode and returns (total reward, frames, number of frames)

        policy_net: network mapping a batch of states to per-action Q values
        testEnv: environment exposing reset_environment() and
                 perform_action(action); the latter must return the 5-tuple
                 (state, reward, done, info, frame) with tensor reward and done
        max_steps: cap on the episode length; defaults to config['max_test_length']
        """

        if max_steps is None:
            max_steps = config['max_test_length']

        episode_frames = []
        state  = testEnv.reset_environment()
        test_reward = 0

        # evaluation only, no gradients needed
        with torch.no_grad():
            for _step in range(max_steps):
                action = policy_net(state.to(self.device)).argmax(dim=1) # get the action
                state, reward, done, _, frame = testEnv.perform_action(action) # perform the action
                test_reward += reward.item() # add the reward
                episode_frames.append(frame.T)

                # stop the loop when done
                if done.item():
                    break

        return test_reward, episode_frames, len(episode_frames)

In [8]:
class CartpoleDQN(nn.Module):
    """
    Fully connected Q network: flattened 3-channel image -> 24 -> 32 -> one value per action
    """

    def __init__(self, img_height, img_width, num_actions):
        super().__init__()

        # every pixel of the three colour channels is an input feature
        flat_pixels = img_height * img_width * 3

        self.fc1 = nn.Linear(flat_pixels, 24)
        self.fc2 = nn.Linear(24, 32)
        self.out = nn.Linear(32, num_actions)

    def forward(self, t):
        # collapse everything but the batch dimension
        activations = t.flatten(start_dim=1)

        # both hidden layers are followed by a ReLU, the output layer is linear
        for hidden in (self.fc1, self.fc2):
            activations = F.relu(hidden(activations))

        return self.out(activations)

In [9]:
class Qvalues():
    """Helpers computing the Q values used in the DQN loss."""

    @staticmethod
    def get_current(policy_net, states, actions):
        """
        Returns the Q(s, a) value of the current batch, with shape (B, 1)

        policy_net: policy network
        states: batch of states of shape (B, C, H, W)
        actions: batch of action indices of shape (B,) or (B, 1)
        """

        # gather needs one index column per row; reshape accepts both the
        # flat (B,) layout and the (B, 1) layout
        index = actions.reshape(-1, 1)

        # return the Q value of the corresponding actions
        return policy_net(states).gather(dim=1, index=index)

    @staticmethod
    def get_next(target_net, next_states, is_terminal):
        """
        Returns the max over a' of Q'(s', a') for the current batch, with
        shape (B,); terminal states get a value of zero

        target_net: target network
        next_states: batch of next states with shape (B, C, H, W)
        is_terminal: mask of B entries indicating which states are terminal
        """

        with torch.no_grad():
            q_values = target_net(next_states).max(dim = 1)[0]

            # cast to bool so a 0/1 integer mask is applied as a mask instead
            # of being interpreted as a list of positions
            terminal_mask = is_terminal.reshape(-1).bool()
            q_values[terminal_mask] = 0
            return q_values

In [10]:
import wandb

# authenticate against weights and biases
wandb.login()
os.environ['WANDB_NOTEBOOK_NAME'] = 'Breakout-bot-v1.ipynb' # notebook name reported to wandb

# hyperparameters and settings of this run
config = {
    'batch_size':         256,
    'epsilon_start':      1,
    'epsilon_end':        0.01,
    'epsilon_decay':      0.001,
    'gamma':              0.999,
    'learning_rate':      0.001,
    'episodes':           1000,
    'sync_time':          10,
    'replay_buffer_size': 100000,
    'loss_fn':            'MSE',
    'min_buffer_size':    256,
    'max_test_length':    1500,
    'fps':                4,
}

# start the wandb run and create the artifact that will hold the model
run = wandb.init(project='Hitchiker-s-Guide-to-the-Galaxy-of-Reinforcement-Learning', config=config)
artifact = wandb.Artifact('policy_net', type='model')

# run the network on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# videos of this run go into a directory named after the run,
# checkpoints into its 'ckpt' sub-directory
run_videos = str(run.name)
run_checkpoints = os.path.join(run_videos, 'ckpt')

# creates both directories; does nothing when they already exist
os.makedirs(run_checkpoints, exist_ok=True)

[34m[1mwandb[0m: Currently logged in as: [33mpeacekurella[0m (use `wandb login --relogin` to force relogin)


In [11]:
# create the env manager
em = CartPoleEnvManager(device)

# create the epsilon scheduler
strategy = EpsilonGreedyStrategy(config['epsilon_start'], config['epsilon_end'], config['epsilon_decay'])

# create the agent
agent = Agent(strategy, em.num_actions_available(), device)

# create the target and policy networks
policy_net = CartpoleDQN(em.get_screen_height(), em.get_screen_width(), em.num_actions_available()).to(device)
target_net = CartpoleDQN(em.get_screen_height(), em.get_screen_width(), em.num_actions_available()).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
wandb.watch(policy_net)

# create the replay memory buffer
replay_memory = ExperienceReplayBuffer(config['replay_buffer_size'])

# create the optimizer
optimizer = torch.optim.Adam(policy_net.parameters(), lr=config['learning_rate'])

# set the loss function
loss_fn = torch.nn.MSELoss(reduction='mean')

# training can only start once a full batch can be drawn from the buffer
min_experiences = max(config['min_buffer_size'], config['batch_size'])


# train for a fixed number of episodes
for episode in range(config['episodes']):

    em.reset()
    state = em.get_state()
    episode_loss = 0

    # The replay buffer is kept in host memory and only the sampled batch is
    # moved to the device: holding up to replay_buffer_size screen tensors on
    # the GPU is the likely cause of the recorded "CUDA error: out of memory".
    state_cpu = state.cpu()

    # count() yields 0, 1, 2, ... so the loop runs until the episode is done
    for timestep in count():
        action     = agent.select_action(state, policy_net)
        reward     = em.take_action(action)
        next_state = em.get_state()
        next_state_cpu = next_state.cpu()

        # fields must follow the Experience order:
        # state, action, reward, next_state, is_terminal
        replay_memory.push(Experience(
            state_cpu,
            action.cpu(),
            reward.cpu(),
            next_state_cpu,
            torch.tensor([em.done], dtype=torch.bool) # tensor so the batch can be concatenated
        ))
        state, state_cpu = next_state, next_state_cpu

        if replay_memory.is_samplable(min_experiences):

            # sample the replay memory
            states, actions, rewards, next_states, terminals = replay_memory.sample(config['batch_size'])

            # get the current Q values
            current_q_values = Qvalues.get_current(
                policy_net,
                states.to(device),
                actions.to(device)
            )

            # get the next Q values
            next_q_values = Qvalues.get_next(
                target_net,
                next_states.to(device),
                terminals.to(device)
            )

            # calculate target Q values
            target_q_values = (next_q_values * config['gamma']) + rewards.to(device)

            # calculate loss and backprop
            loss = loss_fn(current_q_values, target_q_values.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # accumulate as a plain float so no graph is kept alive
            episode_loss += loss.item()


        # if done,  log losses & statistics and start the next episode
        if em.done:
            run.log({
                'EpisodeLoss':   episode_loss,
                'Train/EpisodeDuration': timestep
            })
            break

    # sync the target and policy nets
    if episode % config['sync_time'] == 0:
        target_net.load_state_dict(policy_net.state_dict())

#     # periodically test the model and save the episode videos
#     if episode % 10 == 0:

#         # test the agent
#         test_reward, episode_frames = agent.test_agent(policy_net, testEnv)

#         # save the episode locally for back up
#         DataUtils.save_episode(episode_frames, episode, run_videos)

#         # expects the video to be of shape (t, c, h, w)
#         video = wandb.Video(
#             np.array(episode_frames).transpose(0, 3, 2, 1),
#             fps=config['fps'],
#             caption = str(episode),
#             format='mp4'
#         )

#         # log the video and reward
#         wandb.log({
#             'Test/Video' : video,
#             'Test/EpisodeReward': len(episode_frames)
#         }, step = episode)

#     # periodically save the model weights
#     if episode % 100 == 0:

#         # save locally for back up
#         torch.save(policy_net.state_dict(), os.path.join(run_checkpoints, str(episode)))

#         # use wandb artifact to save the model
#         artifact.add_file(os.path.join(run_checkpoints, str(episode)))


# log the artifact, finish the run and close the environment
run.log_artifact(artifact)
run.join()
em.close()

RuntimeError: CUDA error: out of memory