In [1]:
# %%
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from utils import write_grad_info

2025-03-01 17:40:40.565138: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-01 17:40:40.575750: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740843640.588290  525107 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740843640.591684  525107 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-01 17:40:40.604397: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:

class Actor2Critic(nn.Module):
    """Actor-critic network with a shared MLP trunk and two small heads.

    The trunk maps ``in_channels`` input features to a 32-dimensional
    representation through three Linear+ReLU stages (128 -> 64 -> 32).
    Both heads read that representation:

    * ``actor_head``  -> ``out_channels`` raw outputs (no softmax is applied here)
    * ``critic_head`` -> a single scalar output
    """

    def __init__(self, in_channels, out_channels):
        """Build the trunk and both heads.

        Args:
            in_channels: size of the last dimension of the input tensor.
            out_channels: number of outputs produced by the actor head.
        """
        super().__init__()

        # Layers are created in trunk -> actor -> critic order so that seeded
        # parameter initialisation and state_dict keys stay stable.
        widths = (in_channels, 128, 64, 32)
        trunk = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            trunk.append(nn.Linear(fan_in, fan_out))
            trunk.append(nn.ReLU())
        self.backbone = nn.Sequential(*trunk)

        self.actor_head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, out_channels))
        self.critic_head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, x):
        """Run the shared trunk once and feed its output to both heads.

        Returns:
            A ``(actor, critic)`` tuple: the actor head output (last dimension
            ``out_channels``) and the critic head output (last dimension 1).
        """
        features = self.backbone(x)
        return self.actor_head(features), self.critic_head(features)

In [3]:
# Training hyperparameters (tune here; the cells below only read them).
LR = 1e-3           # learning rate handed to the Adam optimizer
GAMMA = 0.99        # discount factor applied to the bootstrapped next-state value
BATCH_SIZE = 16     # number of environment steps accumulated per optimizer update
ENTROPY_BETA = 0.1  # weight of the entropy term added to the total loss

In [4]:
# Environment; its observation/action spaces fix the network's input/output sizes.
env = gym.make('CartPole-v1')
n_actions = env.action_space.n
n_obserations = env.observation_space.shape[0]

# One network holds both the policy and the value head, so a single Adam
# instance optimises all parameters.
net = Actor2Critic(in_channels=n_obserations, out_channels=n_actions)
opt = optim.Adam(net.parameters(), lr=LR)

# Squared-error loss for the value head, and a TensorBoard writer created with
# its default arguments.
critic_criterion = nn.MSELoss()
writer = SummaryWriter()

In [5]:
# One-step (TD(0)) advantage actor-critic training loop.
# Transitions are processed one at a time; every BATCH_SIZE environment steps
# the accumulated actor, critic and entropy losses are averaged and a single
# optimizer step is taken. Per-component gradients are logged to TensorBoard.
episodes = 1000
step = 0  # global environment-step counter, shared across all episodes

# The accumulators live OUTSIDE the episode loop: `step` is global, so a batch
# can span an episode boundary. Re-creating them per episode would silently
# discard every transition gathered since the last update (typically including
# the terminal transition) while still dividing by BATCH_SIZE.
total_actor_loss = torch.tensor(0.0)
total_critic_loss = torch.tensor(0.0)
total_entropy_loss = torch.tensor(0.0)

for episode in range(episodes):
    state, _ = env.reset()
    state = torch.tensor(state)

    R = 0  # sum of this episode's rewards (after the adjustment below), for logging

    while True:
        step += 1

        # Policy logits and state-value estimate for the current state.
        action_dist, value = net(state)
        categorical = Categorical(logits=action_dist)

        action = categorical.sample()
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        next_state = torch.tensor(next_state)
        # NOTE(review): presumably meant to penalise the failing step; whether the
        # environment ever emits a reward other than 1 is not visible here —
        # confirm, or key the penalty off `terminated` instead.
        if reward != 1:
            reward = -10
        R += reward

        # The TD target is a fixed regression target: evaluate the next state
        # without building a graph so the critic loss only updates V(state).
        with torch.no_grad():
            _, next_value = net(next_state)

        # A terminated episode has no future return, so do not bootstrap from it.
        # A truncated episode (time limit) still bootstraps from next_value.
        bootstrap = 0.0 if terminated else 1.0
        td_target = reward + GAMMA * bootstrap * next_value

        # Advantage = target - V(state): positive when the action turned out
        # better than expected, which must INCREASE its log-probability.
        advantages = (td_target - value).detach()
        # log_prob is computed from normalised logits, avoiding log(probs[...]).
        policy_loss = -(advantages * categorical.log_prob(action)).mean()
        value_loss = critic_criterion(value, td_target)

        total_actor_loss += policy_loss
        total_critic_loss += value_loss
        # Negative sign: minimising this term maximises policy entropy.
        total_entropy_loss += -ENTROPY_BETA * categorical.entropy()

        if step % BATCH_SIZE == 0:
            opt.zero_grad()

            # Per-component gradients are computed only for logging; the graph is
            # retained so the combined loss can still be back-propagated below.
            actor_grad = torch.autograd.grad(total_actor_loss, net.parameters(), retain_graph=True, allow_unused=True)
            critic_grad = torch.autograd.grad(total_critic_loss, net.parameters(), retain_graph=True, allow_unused=True)

            total_loss = (total_actor_loss + total_critic_loss + total_entropy_loss) / BATCH_SIZE
            total_loss.backward()

            write_grad_info(actor_grad, writer, 'actor', step, is_grad=True)
            writer.add_scalar('actor/loss', total_actor_loss.item(), step)

            write_grad_info(critic_grad, writer, 'critic', step, is_grad=True)
            writer.add_scalar('critic/loss', total_critic_loss.item(), step)

            write_grad_info(net.parameters(), writer, 'total', step, is_grad=False)
            writer.add_scalar('total/loss', total_loss.item(), step)

            opt.step()

            # Start a fresh batch; graphs built with the pre-update parameters
            # must not be reused after the in-place optimizer step.
            total_critic_loss = torch.tensor(0.0)
            total_actor_loss = torch.tensor(0.0)
            total_entropy_loss = torch.tensor(0.0)

        if terminated or truncated:
            break

        state = next_state
    writer.add_scalar('reward', R, episode)
    print(f'episode: {episode}, R: {R}')

episode: 0, R: 22.0
episode: 1, R: 31.0
episode: 2, R: 12.0
episode: 3, R: 15.0
episode: 4, R: 26.0
episode: 5, R: 39.0
episode: 6, R: 12.0
episode: 7, R: 12.0
episode: 8, R: 13.0
episode: 9, R: 23.0
episode: 10, R: 33.0
episode: 11, R: 39.0
episode: 12, R: 14.0
episode: 13, R: 16.0
episode: 14, R: 16.0
episode: 15, R: 16.0
episode: 16, R: 34.0
episode: 17, R: 13.0
episode: 18, R: 17.0
episode: 19, R: 21.0
episode: 20, R: 19.0
episode: 21, R: 16.0
episode: 22, R: 12.0
episode: 23, R: 12.0
episode: 24, R: 39.0
episode: 25, R: 29.0
episode: 26, R: 18.0
episode: 27, R: 15.0
episode: 28, R: 29.0
episode: 29, R: 30.0
episode: 30, R: 22.0
episode: 31, R: 15.0
episode: 32, R: 36.0
episode: 33, R: 28.0
episode: 34, R: 20.0
episode: 35, R: 27.0
episode: 36, R: 12.0
episode: 37, R: 16.0
episode: 38, R: 10.0
episode: 39, R: 13.0
episode: 40, R: 12.0
episode: 41, R: 10.0
episode: 42, R: 23.0
episode: 43, R: 21.0
episode: 44, R: 36.0
episode: 45, R: 16.0
episode: 46, R: 15.0
episode: 47, R: 23.0
ep

In [6]:
# Persist only the network's parameters (its state_dict), not the module object.
torch.save(obj=net.state_dict(), f='./a2c.pth')