# In [None]:
import sys
import torch  
import gym
import numpy as np  
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf

# Optimisation hyperparameters
hidden_size = 256      # width of the hidden layer in the actor and critic heads
learning_rate = 1e-5   # step size handed to the Adam optimiser

# Training constants
GAMMA = 0.99           # discount factor applied when accumulating rewards
num_steps = 30000      # bound of the per-episode step loop in a2c
num_frames = 4         # frame-stack depth (ActorCritic sets its own copy of this)
max_episodes = 3000    # number of training episodes run by a2c

class ActorCritic(nn.Module):
    """Convolutional actor-critic network over a stack of 84x84 grayscale frames.

    A shared four-layer convolutional trunk feeds two separate two-layer
    heads: the critic head outputs a single state value and the actor head
    outputs a softmax distribution over ``num_actions`` actions.  The
    instance also owns the rolling frame buffer maintained by
    ``process_image``.

    Args:
        num_inputs: Accepted for interface compatibility; not used.
        num_actions: Number of discrete actions (size of the policy output).
        hidden_size: Width of the hidden layer in both heads.
        learning_rate: Accepted for interface compatibility; not used.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, learning_rate=3e-4):
        super(ActorCritic, self).__init__()

        self.num_frames = 4
        self.img_row = 84
        self.img_col = 84
        # Rolling frame buffer laid out (row, col, frame); process_image
        # always writes the newest frame at index 0 of the last axis.
        self.image_frames = np.zeros((self.img_row, self.img_col, self.num_frames))
        self.num_actions = num_actions

        # Spatial size per layer for an 84x84 input: 20 -> 9 -> 7 -> 1, so the
        # trunk emits exactly 64 features per sample.
        self.conv1 = nn.Conv2d(in_channels=self.num_frames, out_channels=32, kernel_size=8, stride=4, bias=False)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, bias=False)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=7, stride=1, bias=False)

        self.critic_linear1 = nn.Linear(64, hidden_size)
        self.critic_linear2 = nn.Linear(hidden_size, 1)

        self.actor_linear1 = nn.Linear(64, hidden_size)
        self.actor_linear2 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        """Evaluate the critic and the actor on one stacked-frame observation.

        Args:
            state: NumPy array laid out (row, col, frame) with pixel values
                on a 0-255 scale, e.g. the buffer returned by
                ``process_image``.  The array is not modified.

        Returns:
            Tuple ``(value, policy_dist)``: ``value`` is a tensor of shape
            ``(1, 1)`` and ``policy_dist`` a tensor of shape
            ``(1, num_actions)`` whose entries sum to 1.
        """
        # Move the frame axis to the front with a real transpose.  A reshape
        # keeps the memory order and would scatter pixels of every frame
        # across all channels instead of giving each frame its own channel.
        frames_first = np.ascontiguousarray(np.transpose(state, (2, 0, 1)))
        batch = torch.from_numpy(frames_first).float().unsqueeze(0) / 255.0

        # ReLU between the convolutions: without a non-linearity the four
        # layers collapse into a single linear map.
        features = F.relu(self.conv1(batch))
        features = F.relu(self.conv2(features))
        features = F.relu(self.conv3(features))
        features = F.relu(self.conv4(features))
        features = torch.flatten(features, start_dim=1)

        value = F.relu(self.critic_linear1(features))
        value = self.critic_linear2(value)

        policy_dist = F.relu(self.actor_linear1(features))
        policy_dist = F.softmax(self.actor_linear2(policy_dist), dim=1)

        return value, policy_dist

    def process_image(self, cur_img_state):
        """Preprocess one RGB frame and push it onto the rolling frame buffer.

        The frame is converted to grayscale, cropped to the 160x160 window
        that starts at row 34, and resized to 84x84 with nearest-neighbour
        sampling.  The buffer is then rotated by one position along the frame
        axis and the new frame is written at index 0.

        Args:
            cur_img_state: RGB image with at least 194 rows and 160 columns.

        Returns:
            The internal ``(84, 84, 4)`` frame buffer itself (not a copy), so
            later calls change the contents of a previously returned array.
        """
        processed = tf.image.rgb_to_grayscale(cur_img_state)
        processed = tf.image.crop_to_bounding_box(processed, 34, 0, 160, 160)
        cur_img_resized = tf.image.resize(
            processed,
            [self.img_row, self.img_col],
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        )

        # Rotating by one moves the oldest frame (last index) to index 0,
        # where it is overwritten by the current frame.
        self.image_frames = np.roll(self.image_frames, shift=1, axis=2)
        self.image_frames[:, :, 0] = np.squeeze(cur_img_resized)

        return self.image_frames

def _discounted_returns(rewards, bootstrap_value, gamma):
    """Return the discounted return of every step, seeded with ``bootstrap_value``."""
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = bootstrap_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def _plot_training_curves(all_rewards, all_lengths, average_lengths):
    """Plot per-episode reward (raw and 10-episode mean) and episode length."""
    smoothed_rewards = pd.Series(all_rewards).rolling(10).mean().tolist()
    plt.plot(all_rewards)
    plt.plot(smoothed_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.show()

    plt.plot(all_lengths)
    plt.plot(average_lengths)
    plt.xlabel('Episode')
    plt.ylabel('Episode length')
    plt.show()


def a2c(env):
    """Train an ``ActorCritic`` agent on ``env`` with advantage actor-critic.

    Runs ``max_episodes`` episodes of at most ``num_steps`` steps each.  After
    every episode a single Adam update is applied to the combined loss
    (policy gradient + value regression - entropy bonus).  Progress is written
    to stdout every 10 episodes and two matplotlib figures (reward and
    episode length) are shown at the end.

    Args:
        env: Environment following the classic Gym API: ``reset()`` returns
            the first observation and ``step(action)`` returns
            ``(observation, reward, done, info)``.  Observations must be RGB
            frames accepted by ``ActorCritic.process_image``.
    """
    num_inputs = env.observation_space.shape
    num_outputs = env.action_space.n

    actor_critic = ActorCritic(num_inputs, num_outputs, hidden_size)
    ac_optimizer = optim.Adam(actor_critic.parameters(), lr=learning_rate)

    all_lengths = []
    average_lengths = []
    all_rewards = []

    for episode in range(max_episodes):
        log_probs = []
        values = []
        rewards = []
        entropies = []

        # Start every episode from an empty frame stack so frames of the
        # previous episode do not leak into the first observations.
        actor_critic.image_frames = np.zeros_like(actor_critic.image_frames)
        state = env.reset()
        done = False

        # One pass of this loop is one environment step; the loop ends on a
        # terminal state or when the num_steps cap is reached.
        for _step in range(num_steps):
            state_processed = actor_critic.process_image(state)
            value, policy_dist = actor_critic.forward(state_processed)

            action_dist = torch.distributions.Categorical(probs=policy_dist)
            action = action_dist.sample()
            state, reward, done, _info = env.step(action.item())

            rewards.append(reward)
            # Keep everything as tensors attached to the graph: the critic
            # is trained through `values`, the entropy bonus through
            # `entropies`.
            values.append(value.squeeze())
            log_probs.append(action_dist.log_prob(action).squeeze(0))
            entropies.append(action_dist.entropy().squeeze(0))

            if done:
                break

        # A terminal state has no future return; only a truncated episode is
        # bootstrapped from the critic's estimate of the next state.
        if done:
            bootstrap_value = 0.0
        else:
            with torch.no_grad():
                next_value, _ = actor_critic.forward(actor_critic.process_image(state))
            bootstrap_value = next_value.item()

        episode_length = len(rewards)
        all_rewards.append(np.sum(rewards))
        all_lengths.append(episode_length)
        average_lengths.append(np.mean(all_lengths[-10:]))
        if episode % 10 == 0:
            sys.stdout.write("episode: {}, reward: {}, total length: {}, average length: {} \n".format(episode, np.sum(rewards), episode_length, average_lengths[-1]))

        if not rewards:
            # Nothing was collected (num_steps == 0): there is nothing to learn from.
            continue

        # update actor critic
        returns = torch.from_numpy(_discounted_returns(rewards, bootstrap_value, GAMMA))
        values = torch.stack(values)
        log_probs = torch.stack(log_probs)
        entropy_term = torch.stack(entropies).sum()

        advantage = returns - values
        # The advantage is a fixed weight for the policy gradient; only the
        # critic loss is allowed to push gradients into the value estimates.
        actor_loss = (-log_probs * advantage.detach()).mean()
        critic_loss = 0.5 * advantage.pow(2).mean()
        # Entropy is subtracted so that minimising the loss favours exploration.
        ac_loss = actor_loss + critic_loss - 0.001 * entropy_term

        ac_optimizer.zero_grad()
        ac_loss.backward()
        ac_optimizer.step()

    # Plot results
    _plot_training_curves(all_rewards, all_lengths, average_lengths)

if __name__ == "__main__":
    # Script entry point: train the agent on deterministic Atari Breakout.
    env = gym.make("BreakoutDeterministic-v4")
    try:
        a2c(env)
    finally:
        # Release the environment's resources even if training raises or is
        # interrupted from the keyboard.
        env.close()