In [1]:
import cv2
import os
import torch
import gym
import time
import math
import random

from collections import namedtuple 
from itertools import count

import numpy as np
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
class EnvironmentManager(object):
    """
    Wrapper class to manage the Environment.

    Owns the gym Breakout environment, preprocesses raw frames and rewards,
    and maintains the agent's state as a stack of the four most recent
    preprocessed frames with shape (1, 4, H/2, W/2).
    """

    def __init__(self, device):
        """
        device: torch device on which states, rewards and done flags are placed
        """

        # set the environment
        self.env = gym.make("BreakoutDeterministic-v4")

        # set the device
        self.device = device

        # set the current state; stays an empty placeholder until
        # reset_environment() builds the first frame stack
        self.current_state = []

        # set the number of actions
        self.num_actions = self.env.action_space.n

    def perform_action(self, action):
        """
        Performs the action and returns the preprocessed state

        action: single-element tensor holding the action index,
                choose index from ['NOOP', 'FIRE', 'RIGHT', 'LEFT']

        Returns the tuple
        (stacked state, clipped reward tensor, done tensor, info, raw frame, raw reward).
        """

        # pass the action to the environment
        frame, reward, done, info = self.env.step(action.item())

        # preprocess the observation, reward and done flag
        observation = self.preprocess_state(frame)
        reward_clipped = self.preprocess_reward(reward)
        done = torch.tensor([done]).to(self.device)

        # slide the frame window: drop the oldest channel, append the newest
        self.current_state = torch.cat([self.current_state[:, 1:], observation], dim=1)

        return self.current_state, reward_clipped, done, info, frame, reward

    def preprocess_reward(self, reward):
        """
        Preprocesses the reward

        reward: raw scalar reward returned by the environment

        Returns a single-element tensor holding the sign of the reward (-1, 0 or 1).
        """

        # clip the rewards to prevent gradient explosion
        reward = np.sign(reward)
        return torch.tensor([reward]).to(self.device)

    def preprocess_state(self, state):
        """
        Preprocesses the state by converting it to a grayscale image and
        downsampling it to half its original height and width

        state: raw frame of shape (H, W, C)

        Returns a float tensor of shape (1, 1, H/2, W/2) on self.device.
        """

        # set the target shape to be half that of original
        target_shape = (int(state.shape[0] / 2), int(state.shape[1] / 2))

        # compose the transforms
        transform = T.Compose([
            T.ToTensor(),  # convert state to a (C, H, W) float tensor
            T.Grayscale(),  # convert state to Grayscale
            T.Resize(target_shape)  # reshape the image to the target shape
        ])

        # ToTensor already rescales uint8 frames from [0, 255] to [0, 1], so no
        # further division by 255 is applied (doing so squashed every pixel
        # into [0, 1/255]).
        # NOTE(review): assumes the environment emits uint8 frames - confirm.
        return transform(state).unsqueeze(0).to(self.device)

    def reset_environment(self):
        """
        Resets the environment and returns the starting state
        """

        # reset the current env and preprocess the first frame
        obs = self.preprocess_state(self.env.reset())

        # the starting state is the first frame repeated four times
        self.current_state = torch.cat([obs, obs, obs, obs], dim=1)

        # return the current state
        return self.current_state

    def close_environment(self):
        """
        Close the environment
        """

        # close the environment
        self.env.close()

In [3]:
class DataUtils():
    """
    Static helpers for visualising states and saving episodes as videos.
    """

    @staticmethod
    def display_state(state):
        """
        Displays the preprocessed state as a grid of grayscale images,
        one row per batch element and one column per channel

        state: numpy array of shape (B, C, H, W)
        """
        n_rows, n_cols = state.shape[0], state.shape[1]

        # squeeze=False keeps `axes` two-dimensional even when B or C is 1,
        # so the [row, column] indexing below is always valid
        fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
        for i in range(n_rows):
            for j in range(n_cols):
                # draw channel j of batch element i in its own subplot
                axes[i, j].imshow(state[i, j], cmap='gray')
        plt.show()

    @staticmethod
    def save_episode(frames, episode, path):
        """
        Saves the episode as a video, returns the path of the saved video

        frames: A non-empty list of frames of shape (H, W, C) to be saved as the video
        episode: Episode number, used as the file name
        path: directory to save the video in
        """

        # the video is written at the native resolution of the first frame
        height, width = frames[0].shape[:2]

        path = os.path.join(path, str(episode) + '.mp4')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(path, fourcc, 24, (width, height))

        # NOTE(review): cv2.VideoWriter presumably expects uint8 BGR frames whose
        # size matches (width, height) - confirm that callers pass frames in
        # that layout.
        try:
            for frame in frames:
                video.write(frame)
        finally:
            # always release the writer so the file is finalised even on error
            video.release()

        return path

In [4]:
# Create a Named Tuple to hold our experiences 

Experience = namedtuple(
    'Experience',  # class name
    [
        'state',        # state the action was taken in
        'action',       # action taken
        'reward',       # reward received
        'next_state',   # state reached after the action
        'is_terminal',  # whether the transition ended the episode
    ],
)

In [5]:
class ExperienceReplayBuffer(object):
    """
    Fixed-capacity store of experiences; once full, the oldest
    entries are overwritten in insertion order.
    """

    def __init__(self, size):

        self.max_size = size  # capacity of the buffer
        self.push_count = 0  # total number of experiences ever added
        self.buffer = []  # storage, grows until it reaches max_size

    def add_experience(self, experience):
        """
        Stores an experience, evicting the oldest one when the buffer is full

        experience: An Experience namedtuple
        """

        if len(self.buffer) >= self.max_size:
            # buffer is full: overwrite the slot holding the oldest experience
            oldest_slot = self.push_count % self.max_size
            self.buffer[oldest_slot] = experience
        else:
            # still room left: simply grow the buffer
            self.buffer.append(experience)

        # one more experience has been pushed
        self.push_count += 1

    def is_samplable(self, batch_size):
        """
        Tells whether the buffer holds enough experiences for one batch

        batch_size: The batch size of samples
        """

        return not len(self.buffer) < batch_size

    def sample(self, batch_size):
        """
        Draws a random batch of experiences without replacement

        batch_size: sample size

        Returns a tuple (states, actions, rewards, next_states, is_terminal)
        in which each entry is the concatenation of that field over the batch.
        """

        # pick the experiences, then regroup them field by field
        picked = random.sample(self.buffer, batch_size)
        columns = Experience(*zip(*picked))

        field_order = ('state', 'action', 'reward', 'next_state', 'is_terminal')
        return tuple(torch.cat(getattr(columns, name)) for name in field_order)

In [6]:
class EpsilonGreedyStrategy(object):
    """
    Exploration schedule: epsilon decays exponentially from
    `start` towards `end` as the step count grows.
    """

    def __init__(self, start, end, decay):

        # epsilon at step 0, its asymptotic value, and the decay rate
        self.start, self.end, self.decay = start, end, decay

    def get_epsilon(self, current_step):
        """
        Computes epsilon for the given step

        current_step: Current step number
        """

        # portion of epsilon that is still subject to decay
        span = self.start - self.end

        # exponential decay factor: 1 at step 0, approaching 0 over time
        decay_factor = math.exp(-1.0 * current_step * self.decay)

        return self.end + span * decay_factor
        

In [7]:
class Agent(object):
    """
    Defines the agent: picks actions epsilon-greedily and counts
    how many actions it has chosen so far.
    """

    def __init__(self, num_actions, strategy, device):

        self.steps = 0  # set the number of steps taken
        self.strategy = strategy  # set the exploration strategy
        self.num_actions = num_actions  # set the number of actions possible
        self.device = device  # set the device

    def choose_action(self, policy_net, state):
        """
        Chooses an action depending on the strategy

        policy_net: policy network
        state: Agent's state in the env

        Returns a tensor on self.device holding the chosen action index.
        """

        # epsilon for the current step; every call counts as one step
        epsilon = self.strategy.get_epsilon(self.steps)
        self.steps += 1

        # if random value is less than epsilon, then explore
        if random.random() < epsilon:
            action = random.randrange(self.num_actions)  # random action to explore env
            return torch.tensor([action]).to(self.device)

        # otherwise exploit: return the action with the max Q value.
        # Action selection never needs gradients, so skip building the graph.
        with torch.no_grad():
            return policy_net(state).argmax(dim=1).to(self.device)

In [8]:
class Qnetwork(nn.Module):
    """
    Defines our Q network: two convolutional layers followed by two fully
    connected layers, producing one Q value per action.
    """

    def __init__(self, num_actions):
        """
        num_actions: number of possible actions, i.e. the size of the output
        """

        super(Qnetwork, self).__init__()  # call the super class constructor
        self.num_actions = num_actions  # set the number of possible actions

        self.conv1 = nn.Conv2d(4, 16, kernel_size=8, stride=4)  # 4 channels in, 16 channels out
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)  # 16 channels in, 32 channels out

        # 32 * 11 * 8 is the flattened conv2 output for a 105 x 80 input frame
        self.fc1 = nn.Linear(32 * 11 * 8, 256)
        self.fc2 = nn.Linear(256, num_actions)

    def forward(self, state):
        """
        defines the forward pass through the network

        state: Input state to the network, shape (B, 4, H, W)

        Returns the Q values, shape (B, num_actions).
        """

        out = F.relu(self.conv1(state))  # first conv layer + relu
        out = F.relu(self.conv2(out))  # second conv layer + relu

        # flatten everything except the batch dimension, so a wrongly sized
        # input fails in fc1 instead of being silently folded into the batch
        out = out.flatten(start_dim=1)

        out = F.relu(self.fc1(out))  # first fc layer + relu

        # The output layer is linear: Q values are unbounded regression
        # targets, and a relu here would clamp them at zero and stop the
        # gradient for any output that turns negative.
        return self.fc2(out)

In [9]:
class Qvalues():
    """
    Static helpers computing the two sides of the Bellman update.
    """

    @staticmethod
    def get_q_values(policy_net, states, actions):
        """
        Returns the Q(s, a) value of the current batch, shape (B, 1)

        policy_net: policy network
        states: batch of states of shape (B, C, H, W)
        actions: batch of integer action indices of shape (B,)
        """

        # pick, for every row, the Q value of the action that was taken
        return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

    @staticmethod
    def get_q_prime_values(target_net, next_states, is_terminal):
        """
        Returns the max over a' of Q'(s', a') for the current batch, shape (B, 1)

        target_net: target network
        next_states: batch of next states with shape (B, C, H, W)
        is_terminal: boolean mask of shape (B,) indicating which states are terminal
        """

        # The bootstrap target is a constant in the loss: compute it without
        # autograd so no gradient reaches the target network, and avoid
        # mutating the network output in place.
        with torch.no_grad():
            best_q = target_net(next_states).max(dim=1).values  # max Q value per state
            best_q = best_q.masked_fill(is_terminal, 0)  # terminal states have no future value

        return best_q.reshape(-1, 1)

In [10]:
import wandb

# login to weights and biases
wandb.login()
os.environ['WANDB_NOTEBOOK_NAME'] = 'Breakout-bot-v1.ipynb' # set the name of the notebook

# set up the run configuration
config = dict(
    batch_size         = 32,
    epsilon_start      = 1,
    epsilon_end        = 0.1,
    epsilon_decay      = 1e-4,
    gamma              = 0.99,
    learning_rate      = 0.0001,
    episodes           = 1000000,
    sync_time          = 1000, 
    replay_buffer_size = 40000,
    loss_fn            = 'Huber'
)

# initialize wandb
run = wandb.init(project='Hitchiker-s-Guide-to-the-Galaxy-of-Reinforcement-Learning', config=config)
artifact = wandb.Artifact('policy_net', type='model')

# device to put the networks on; fall back to the cpu when no gpu is available
net_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
mem_device = 'cpu:0' # device to put the agent, env, and replay buffer on

# videos of the current run go into a directory named after the run,
# checkpoints into its 'ckpt' sub-directory
run_videos = str(run.name)
run_checkpoints = os.path.join(run_videos, 'ckpt')

# make the checkpoint and video directories if they do not exist
os.makedirs(run_checkpoints, exist_ok=True)

[34m[1mwandb[0m: Currently logged in as: [33mpeacekurella[0m (use `wandb login --relogin` to force relogin)


In [None]:
# create the env manager on cpu to save gpu memory
envMan = EnvironmentManager(mem_device)

# create the epsilon scheduler
strategy = EpsilonGreedyStrategy(config['epsilon_start'], config['epsilon_end'], config['epsilon_decay'])

# create the agent
agent = Agent(envMan.num_actions, strategy, mem_device)

policy_net = Qnetwork(envMan.num_actions).to(net_device) # create the policy net
target_net = Qnetwork(envMan.num_actions).to(net_device) # create the target net
target_net.load_state_dict(policy_net.state_dict())  # sync the weights of target and policy nets
target_net.eval() # the target network is only ever used for inference

# create the replay memory buffer
replay_memory = ExperienceReplayBuffer(config['replay_buffer_size'])

# create the optimizer
optimizer = torch.optim.Adam(policy_net.parameters(), lr=config['learning_rate'])

# set the loss function as Huber loss
loss_fn = torch.nn.SmoothL1Loss(reduction='mean')

# number of environment steps taken over the whole run; the target network is
# synced on this counter because the per-episode timestep restarts every episode
total_steps = 0

try:
    # Go through all the episodes
    for episode in range(config['episodes']):

        # reset the state at the start of every episode
        state = envMan.reset_environment()
        episode_reward_clipped = 0
        episode_reward         = 0
        episode_loss           = 0
        episode_frames         = []  # raw frames of shape (H, W, C)

        # keep going through the episode until it's done
        for timestep in count(): # count is a py-function that keeps track of current iteration

            # render the env, comment next line to stop showing the network playing
            envMan.env.render()

            # make the agent choose an action
            action = agent.choose_action(policy_net, state.to(net_device))

            # perform the action
            next_state, reward, done, info, frame, reward_unc = envMan.perform_action(action)
            total_steps            += 1
            episode_reward_clipped += reward.item()
            episode_reward         += reward_unc
            episode_frames.append(frame)

            # add the experience to replay memory, then update the current state
            replay_memory.add_experience(Experience(state, action, reward, next_state, done))
            state = next_state

            # check if replay memory is samplable
            if replay_memory.is_samplable(config['batch_size']):

                # sample the replay memory
                states, actions, rewards, next_states, is_terminal = replay_memory.sample(config['batch_size'])

                # Q(s, a) of the policy network for the actions that were taken
                q_value = Qvalues.get_q_values(policy_net, states.to(net_device), actions.to(net_device))

                # the bellman target is a constant: keep it out of the autograd graph
                with torch.no_grad():
                    q_prime = Qvalues.get_q_prime_values(target_net, next_states.to(net_device), is_terminal.to(net_device))
                    q_star  = (rewards.unsqueeze(-1).to(net_device) + (config['gamma'] * q_prime)).float() # refer to the bellman eqn

                # calculate loss, do back prop
                loss = loss_fn(q_value, q_star)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                episode_loss += loss.item() # track the episode loss

            # sync the target and policy nets every `sync_time` environment steps
            if total_steps % config['sync_time'] == 0:
                target_net.load_state_dict(policy_net.state_dict())

            # if done, log losses & statistics and start the next episode
            if done.item():
                run.log({
                    'EpisodeLoss':   episode_loss,
                    'EpisodeReward': episode_reward_clipped,
                    'EpisodeRewardUnclipped': episode_reward,
                    'EpisodeLength': len(episode_frames)
                })
                break

        # periodically save the episode videos
        if episode % 1000 == 0:

            # save the episode locally for back up; save_episode unpacks frames as
            # (H, W, C) and writes them with OpenCV
            # NOTE(review): assumes the env emits RGB frames and that OpenCV wants
            # BGR, hence the colour conversion - confirm.
            bgr_frames = [cv2.cvtColor(f, cv2.COLOR_RGB2BGR) for f in episode_frames]
            DataUtils.save_episode(bgr_frames, episode, run_videos)

            # expects the video to be of shape (t, c, h, w)
            video = wandb.Video(np.array(episode_frames).transpose(0, 3, 1, 2), fps=24, caption = str(episode))
            wandb.log({'Episode-' + str(episode) : video})

        # periodically save the model weights
        if episode % 10000 == 0:

            # save locally for back up
            checkpoint_path = os.path.join(run_checkpoints, str(episode))
            torch.save(policy_net.state_dict(), checkpoint_path)

            # use wandb artifact to save the model
            artifact.add_file(checkpoint_path)

finally:
    # runs on normal completion, on errors and on manual interruption alike, so
    # the collected checkpoints are uploaded and the environment is released
    try:
        run.log_artifact(artifact)
        run.join()
    finally:
        # close the environment
        envMan.close_environment()