In [5]:
# Implement the actor-critic algorithm for the Catch environment from scratch using PyTorch

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from catch import Catch
import argparse
import gym
from itertools import count
from collections import namedtuple
import torch.nn.functional as F


In [6]:
# #Define the actor critic network architecture
# class ActorCritic(nn.Module):

#     def __init__(self):
#         super(ActorCritic, self).__init__()
#         self.affine1 = nn.Linear(2, 128)

#         # actor's layer
#         self.action_head = nn.Linear(128, 2)

#         # critic's layer
#         self.value_head = nn.Linear(128, 1)

#         # action & reward buffer
#         self.saved_actions = []
#         self.rewards = []

#     def forward(self, x):
#         """
#         forward of both actor and critic
#         """
#         x = F.relu(self.affine1(x))

#         # actor: chooses the action to take from state s_t
#         # by returning probability of each action
#         action_prob = F.softmax(self.action_head(x), dim=-1)

#         # critic: evaluates being in the state s_t
#         state_values = self.value_head(x)

#         # return values for both actor and critic as a tuple of 2 values:
#         # 1. a list with the probability of each action over the action space
#         # 2. the value from state s_t
#         return action_prob, state_values


# model = ActorCritic()
# optimizer = optim.Adam(model.parameters(), lr=3e-2)
# eps = np.finfo(np.float32).eps.item()

In [7]:
# Define the Actor and Critic network architectures

class Actor(nn.Module):
    """Policy network mapping a state to a probability vector over actions.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> softmax, where the
    softmax is taken over the last dimension (size ``output_size``).

    Args:
        input_size: Size of the last dimension of the state tensor.
        output_size: Number of actions.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Registration order (fc1, fc2, fc3) fixes parameter names and init order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, state):
        """Return action probabilities for ``state`` (rows sum to 1 on the last dim)."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        logits = self.fc3(hidden)
        return logits.softmax(dim=-1)

class Critic(nn.Module):
    """State-value network producing one scalar estimate per state.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear(hidden_size, 1).
    The trailing dimension of the output has size 1.

    Args:
        input_size: Size of the last dimension of the state tensor.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Same layer names / registration order as before, so checkpoints still load.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, state):
        """Return the value estimate for ``state`` with a trailing dimension of 1."""
        features = self.fc2(self.fc1(state).relu()).relu()
        return self.fc3(features)

In [8]:
class ActorCriticAgent:
    """Actor-critic agent for the Catch environment, updated once per episode.

    Losses (computed from one complete episode of length T):

    * actor:  ``-mean_t(log pi(a_t|s_t) * A_t) - entropy_strength * mean_t(H[pi(.|s_t)])``
      (the entropy term is subtracted so that higher entropy lowers the loss);
    * critic: ``MSE(V(s_t), G_t)``.

    Targets ``G_t`` depend on ``use_bootstrapping``:

    * ``True``  -- n-step targets with ``n = trace_length``:
      ``sum_{k<n} gamma^k r_{t+k} + gamma^n V(s_{t+n})``; the bootstrap term is
      dropped when ``t + n`` runs past the end of the episode;
    * ``False`` -- full discounted Monte-Carlo returns.

    Advantages are ``A_t = G_t - V(s_t)``, detached from the graph. When
    ``use_baseline`` is set, they are additionally mean-centred over the episode.
    """

    def __init__(self, input_size, output_size, hidden_size, gamma, lr, trace_length, device,
                 use_baseline=False, use_bootstrapping=False, entropy_strength=0.01):
        self.actor = Actor(input_size, output_size, hidden_size).to(device)
        self.critic = Critic(input_size, hidden_size).to(device)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma
        self.device = device
        self.use_baseline = use_baseline
        self.use_bootstrapping = use_bootstrapping
        # Single source of truth for the entropy coefficient (train() reads this name);
        # `entropy_weight` below is a read/write alias for the previous attribute name.
        self.entropy_strength = entropy_strength
        self.trace_length = trace_length

    @property
    def entropy_weight(self):
        """Alias of ``entropy_strength`` (the attribute name used previously)."""
        return self.entropy_strength

    @entropy_weight.setter
    def entropy_weight(self, value):
        self.entropy_strength = value

    def train(self, episodes, env=None):
        """Train actor and critic for ``episodes`` episodes (one update per episode).

        Args:
            episodes: Number of episodes to play.
            env: Environment exposing ``reset() -> state`` and
                ``step(action) -> (state, reward, done)``. Defaults to
                ``Catch(grid_size=7)``.

        Returns:
            List with the undiscounted sum of rewards of each episode.
        """
        if env is None:
            env = Catch(grid_size=7)

        episode_rewards = []
        for _ in range(episodes):
            states, log_probs, entropies, rewards = self._run_episode(env)

            # squeeze(-1) (not squeeze()) keeps shape (T,) even for 1-step episodes.
            values = self.critic(states).squeeze(-1)
            targets = self._compute_targets(rewards, values)
            advantages = self.compute_advantages(env, targets - values)

            # Subtracting the entropy rewards exploration (the old code had the sign flipped).
            policy_loss = (-(log_probs * advantages).mean()
                           - self.entropy_strength * entropies.mean())
            # Targets are built from detached values, so the critic regresses onto constants.
            value_loss = nn.MSELoss()(values, targets)
            loss = policy_loss + value_loss

            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            loss.backward()
            self.optimizer_actor.step()
            self.optimizer_critic.step()

            episode_rewards.append(sum(rewards))
        return episode_rewards

    def _run_episode(self, env):
        """Roll out one episode with the current policy.

        Returns:
            ``(states, log_probs, entropies, rewards)``: stacked per-step state
            tensors (first dim T), per-step action log-probabilities (T,), per-step
            policy entropies (T,), and a list of T float rewards.
        """
        states, log_probs, entropies, rewards = [], [], [], []
        state = env.reset()
        done = False
        while not done:
            state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            dist = torch.distributions.Categorical(self.actor(state_t))
            action = dist.sample()
            next_state, reward, done = env.step(action.item())

            states.append(state_t)
            log_probs.append(dist.log_prob(action))
            entropies.append(dist.entropy())
            rewards.append(float(reward))
            state = next_state
        return torch.stack(states), torch.stack(log_probs), torch.stack(entropies), rewards

    def _compute_targets(self, rewards, values):
        """Build per-step critic targets ``G_t`` (see class docstring).

        Args:
            rewards: List of T float rewards.
            values: Tensor of T critic estimates ``V(s_t)``; only read (detached).

        Returns:
            Float32 tensor of shape (T,) on ``self.device``; carries no gradient.
        """
        num_steps = len(rewards)
        v = values.detach().reshape(-1).tolist()
        targets = [0.0] * num_steps

        if self.use_bootstrapping:
            n = max(1, int(self.trace_length))
            for t in range(num_steps):
                end = min(t + n, num_steps)
                g = sum(self.gamma ** k * rewards[t + k] for k in range(end - t))
                if end < num_steps:  # s_{t+n} exists inside the episode: bootstrap from it
                    g += self.gamma ** n * v[end]
                targets[t] = g
        else:
            running = 0.0
            for t in reversed(range(num_steps)):
                running = rewards[t] + self.gamma * running
                targets[t] = running

        return torch.tensor(targets, dtype=torch.float32, device=self.device)

    def compute_advantages(self, env, delta):
        """Turn per-step ``G_t - V(s_t)`` estimates into policy-gradient weights.

        Args:
            env: Unused; kept so existing call sites keep working.
            delta: Tensor of per-step ``G_t - V(s_t)`` values.

        Returns:
            A detached copy of ``delta``, so the actor loss sends no gradient into the
            critic. It is mean-centred when ``use_baseline`` is set. ``delta`` itself is
            not modified.
        """
        advantages = delta.detach()
        if self.use_baseline:
            advantages = advantages - advantages.mean()
        return advantages

           