In [1]:
import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, Categorical
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from collections import deque
import cv2

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [9]:
env = gym.make("SpaceInvaders-v0")
epsilon = 1.0
epsilonMin = 0.01
epsilonDecay = 0.999
episodes = 1000
batch_size = 64
gamma = 0.99
goal_steps = 200
num_actions = env.action_space.n
# print(type(num_actions))
buffer_capacity = 500
epochs = 1
clip_param = 0.2

In [3]:
# Memory to save the experiences 
class Buffer(object):
    def __init__(self):
        self.buffer = []
        self.buffer_capacity = 1000
        self.batch = 32
    
    def add(self, params):
        self.buffer.append(params)
        
    def reinit(self):
        self.buffer = []
        
    def length(self):
        return len(self.buffer)

In [4]:
# Preprocess reduces dimension and converts frame of 210x160x3 to 84x84
def preprocess(observation):
    observation = cv2.cvtColor(cv2.resize(observation, (84, 110)), cv2.COLOR_BGR2GRAY)
    observation = observation[26:110,:]
    ret, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(observation, (84, 84))

In [5]:
# Stack 4 frames (4 frames used to give idea of motion) to create data set
stack_size = 4
stacked_frames  =  deque([np.zeros((84,84), dtype=np.int) for i in range(stack_size)], maxlen=4)

def stack_frames(stacked_frames, state, is_new_episode):
    frame = preprocess(state)
    if is_new_episode:
        stacked_frames = deque([np.zeros((84,84), dtype=np.int) for i in range(stack_size)], maxlen=4)
        
        stacked_frames.append(frame)
        stacked_frames.append(frame)
        stacked_frames.append(frame)
        stacked_frames.append(frame)
        
        stacked_state = np.stack(stacked_frames, axis=0)
        
    else:
        stacked_frames.append(frame)
        stacked_state = np.stack(stacked_frames, axis=0)
    return stacked_state, stacked_frames

In [6]:
# Network for Actor and Critic
class Actor(nn.Module):
    def __init__(self, input_shape1, num_actions1):
        super(Actor, self).__init__()
        self.conv1 = nn.Conv2d(input_shape1, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc4 = nn.Linear(7*7*64, 512)
        self.fc5 = nn.Linear(512, num_actions1)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        mu = F.softmax(self.fc5(x))
        
        c = Categorical(mu)
        return c

class Critic(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(Critic, self).__init__()
        self.conv1 = nn.Conv2d(input_shape, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc4 = nn.Linear(7*7*64, 512)
        self.fc5 = nn.Linear(512, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc4(x.view(x.size(0), -1)))
        value = self.fc5(x)
        return value

In [7]:
# Update method for both policy and value
def update():
    mem = memory.buffer
    # print(mem[0][0].numpy().squeeze(0).shape)
    s = torch.FloatTensor([m[0].numpy().squeeze(0) for m in mem])
    a = torch.FloatTensor([m[1] for m in mem]) 
    old_log_a = torch.FloatTensor([m[2] for m in mem])
    r = torch.FloatTensor([m[3] for m in mem])
    s_ = torch.FloatTensor([m[4] for m in mem])
    
    with torch.no_grad():
        target = r + gamma * critic(s_).squeeze(1)
    adv = target - critic(s).squeeze(1)
    
    for _ in range(epochs):
        print('1')
        for id in BatchSampler(SubsetRandomSampler(range(buffer_capacity)), batch_size, False):
            dist = actor(s[id])
            new_log_a = dist.log_prob(a[id].unsqueeze(1))
            ratio = torch.exp(new_log_a.squeeze(1) - old_log_a[id])
            
            loss1 = ratio * adv[id]
            loss2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv[id]
            action_loss = -torch.min(loss1, loss2).mean()
#             print(type(action_loss))
            opt_a.zero_grad()
            action_loss.backward(retain_graph=True)
            nn.utils.clip_grad_norm_(actor.parameters(), 0.5)
            opt_a.step()

            value_loss = (critic(s[id]).squeeze(1) - target[id]).pow(2).mean()
#             print(type(value_loss))
            opt_c.zero_grad()
            value_loss.backward(retain_graph=True)
            nn.utils.clip_grad_norm_(critic.parameters(), 0.5)
            opt_c.step()
    
    memory.reinit()

In [8]:
state = env.reset()
state, stacked_frames = stack_frames(stacked_frames, state, True)
input_size = state.shape[0]
# print(type(input_size))
actor = Actor(input_size, num_actions)
critic = Critic(input_size, num_actions)

opt_a = optim.Adam(actor.parameters(), lr=1e-4)
opt_c = optim.Adam(critic.parameters(), lr=3e-4)

In [9]:
memory = Buffer()

for idx in range(episodes):
    state = env.reset()
    state, stacked_frames = stack_frames(stacked_frames, state, True)
    done = False
    score = 0
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0)
#         print(state.shape)
        value = critic(state)
        dist = actor(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        # action = action.clamp(-2, 2)
        next_state, reward, done, _ = env.step(action.numpy())
        next_state, stacked_frames = stack_frames(stacked_frames, next_state, False)
        score += reward
        memory.add([state, action, log_prob, reward, next_state])
        state = next_state
        
        if memory.length() >= buffer_capacity:
            # print(memory.buffer[0])
            update()
    print("Episode = " + str(idx) + ", Score = " + str(score))

  app.launch_new_instance()


Episode = 0, Score = 210.0
1
1
1
1
1
Episode = 1, Score = 105.0
1


KeyboardInterrupt: 