In [6]:
# Standard library imports
import random, os

# Third-party imports
import torch
from torch import Tensor, nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
from tqdm.notebook import tqdm
from numpy._typing import _ShapeLike

# Local application imports
from agent_base import AgentBase
import utils

In [5]:
class ActorCriticNetwork(nn.Module):
    """MLP with a shared trunk feeding a value head and an action-logit head.

    Args:
        n_inp: Size of the last dimension of the input tensor.
        n_actions: Number of outputs of the actor head (one logit per action).
        features: Widths of the hidden layers of the shared trunk. Any
            sequence of ints (list, tuple, ...) is accepted. With an empty
            sequence the trunk is an empty ``nn.Sequential`` and both heads
            read the input directly.
    """

    def __init__(self, n_inp, n_actions, features = (128, 128)):
        super().__init__()
        # Build a fresh list: no shared mutable default, and tuples work too.
        layer_sizes = [n_inp, *features]

        # Shared trunk: Linear -> ReLU for every consecutive pair of widths.
        self.encoded = nn.Sequential()
        for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):
            self.encoded.append(nn.Linear(in_dim, out_dim))
            self.encoded.append(nn.ReLU(inplace=True))

        # Both heads read the output of the last trunk layer.
        self.critic = nn.Linear(layer_sizes[-1], 1)
        self.actor = nn.Linear(layer_sizes[-1], n_actions)

    def forward(self, X: Tensor):
        """Run the trunk once and return ``(value, action_logits)``.

        The first element is the critic output (last dimension 1); the second
        is the raw, un-normalised actor output (last dimension ``n_actions``).
        """
        encoded = self.encoded(X)
        return self.critic(encoded), self.actor(encoded)

In [41]:
class ActorCritic(AgentBase):
    """One-step actor-critic agent for a discrete action space.

    A single ``ActorCriticNetwork`` produces both the state value and the
    action logits. ``act`` samples an action and remembers it; ``learn`` then
    performs one gradient step on the transition that action produced.

    Args:
        state_shape: Shape of a single observation; ``state_shape[0]`` is used
            as the network input size.
        action_shape: Forwarded unchanged to ``AgentBase``.
        n_actions: Number of discrete actions (size of the actor head).
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        batch_size, max_mem_size, update_every, seed, **kwargs: Forwarded
            unchanged to ``AgentBase``.
        device: Torch device the network and all tensors are placed on.
    """

    def __init__(
            self, 
            state_shape: _ShapeLike, 
            action_shape: _ShapeLike, 
            n_actions: int,
            lr = 3e-5,
            gamma = 0.99,
            batch_size: int = 64, 
            max_mem_size: int = int(1e5),  
            update_every: int = 1, 
            device: str = 'cuda' if torch.cuda.is_available() else 'cpu', 
            seed = 0, 
            **kwargs
        ) -> None:
        super().__init__(state_shape, action_shape, batch_size, max_mem_size, update_every, device, seed, **kwargs)

        self.gamma = gamma
        self.device = device

        self.actor_critic = ActorCriticNetwork(state_shape[0], n_actions).to(self.device)

        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=lr)

    def _as_batch(self, state) -> Tensor:
        """Convert one observation (array or tensor) to a float32 ``(1, n)`` tensor on ``self.device``."""
        return torch.as_tensor(state, dtype=torch.float32, device=self.device).reshape(1, -1)

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample an action for ``state`` from the current policy.

        The sampled action tensor is stored on ``self.action`` so that the
        following ``learn`` call can compute its log-probability.

        Returns:
            The sampled action as a Python int.
        """
        # The network returns a (value, logits) tuple; only the logits are needed here.
        _, logits = self.actor_critic(self._as_batch(state))
        # `logits=` applies the softmax normalisation internally.
        action = torch.distributions.Categorical(logits=logits.squeeze(0)).sample()
        self.action = action
        return action.item()

    def learn(self, state: Tensor, reward: Tensor, next_state: Tensor, terminal: Tensor):
        """Take one TD(0) actor-critic gradient step on a single transition.

        Uses the action stored by the most recent ``act`` call, so ``act``
        must have been called for ``state`` beforehand.

        Args:
            state: Observation the stored action was taken in.
            reward: Scalar reward received for that action.
            next_state: Observation reached after the action.
            terminal: True when ``next_state`` is terminal, i.e. no value
                should be bootstrapped from it.

        Tensors, NumPy arrays and plain Python scalars are all accepted; they
        are converted to tensors on ``self.device``.
        """
        reward = torch.as_tensor(reward, dtype=torch.float32, device=self.device).reshape(())
        # Convert to a bool tensor first: `~` on a Python bool is bitwise (-1 / -2), not a logical not.
        terminal = torch.as_tensor(terminal, dtype=torch.bool, device=self.device).reshape(())
        not_terminal = (~terminal).to(torch.float32)

        value, logits = self.actor_critic(self._as_batch(state))
        value = value.reshape(())

        # The bootstrap target is treated as a constant (semi-gradient TD).
        with torch.no_grad():
            next_value, _ = self.actor_critic(self._as_batch(next_state))
            td_target = reward + self.gamma * next_value.reshape(()) * not_terminal

        policy = torch.distributions.Categorical(logits=logits.squeeze(0))
        log_prob: Tensor = policy.log_prob(self.action)

        delta = td_target - value
        # The TD error only weights the policy gradient; detach it so the
        # actor term does not push gradients through the critic head.
        actor_loss = -log_prob * delta.detach()
        critic_loss = delta.pow(2)

        total_loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()



In [44]:
# Create the CartPole-v1 environment through the project-local `utils.make` helper.
# NOTE(review): presumably returns a Gymnasium-style env (the training loop
# unpacks a 5-tuple from `env.step`) — confirm against `utils`.
env = utils.make('CartPole-v1')

Discrete(2)

In [46]:
# Size the agent from the environment: the observation shape drives the
# network input, and the action-space size gives the number of actor outputs.
agent = ActorCritic(
    env.observation_space.shape,  # state_shape
    (1, ),                        # action_shape
    n_actions = env.action_space.n,
)

In [None]:
n_games = 1000
scores = []


def _to_tensor(value, dtype):
    """Convert an env value (array / float / bool) to a tensor on the agent's device."""
    return torch.as_tensor(value, dtype=dtype, device=agent.device)


for i in tqdm(range(n_games)):
    done = False
    state = env.reset()[0]
    score = 0
    while not done:
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode is over on either signal; stepping past a truncation
        # would keep the episode running beyond its time limit.
        done = terminated or truncated

        # `learn` is declared to take tensors. Only a true termination should
        # disable bootstrapping, so pass `terminated`, not `done`.
        agent.learn(
            _to_tensor(state, torch.float32),
            _to_tensor(reward, torch.float32),
            _to_tensor(next_state, torch.float32),
            _to_tensor(terminated, torch.bool),
        )

        state = next_state
        score += reward
        
    scores.append(score)
    # Running average over the most recent 100 episodes.
    avg_score = np.mean(scores[-100:])
    print(f'Episode: {i}, Score: {score}, Avg Score: {avg_score}')
