In [1]:
# %%
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from utils import write_grad_info

2025-03-02 04:14:24.393710: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-02 04:14:24.403807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740881664.417603  530957 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740881664.421660  530957 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-02 04:14:24.436045: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:

class Actor2Critic(nn.Module):
    """Actor-critic MLP with a shared trunk and two output heads.

    The trunk maps the last input dimension through 128 -> 64 -> 32 hidden
    units, with a ReLU after each Linear layer. Both heads take the 32-unit
    trunk features through a 16-unit hidden layer:

    * the actor head outputs ``out_channels`` unnormalized action logits;
    * the critic head outputs a single state-value estimate.

    Args:
        in_channels: size of the input feature dimension (observation size).
        out_channels: number of discrete actions (number of logits).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Trunk: Linear/ReLU pairs, built in the same order as a literal
        # nn.Sequential so module indices (and state_dict keys such as
        # "backbone.0.weight") stay stable for saved checkpoints.
        widths = (in_channels, 128, 64, 32)
        trunk_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            trunk_layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.backbone = nn.Sequential(*trunk_layers)

        def make_head(n_outputs):
            # 32 trunk features -> 16 hidden -> n_outputs, no final activation.
            return nn.Sequential(
                nn.Linear(32, 16),
                nn.ReLU(),
                nn.Linear(16, n_outputs),
            )

        # Actor is created before critic to keep parameter-init order unchanged.
        self.actor_head = make_head(out_channels)
        self.critic_head = make_head(1)

    def forward(self, x):
        """Return ``(action_logits, state_value)`` computed from input ``x``."""
        features = self.backbone(x)
        return self.actor_head(features), self.critic_head(features)

In [3]:
# Hyperparams
GAMMA = 0.99        # discount factor in the one-step TD target
LR = 1e-3           # Adam learning rate
BATCH_SIZE = 16     # environment steps whose losses are accumulated per update

# Entropy-bonus weight (beta): decays linearly from START_BETA to BETA_END
# over the first BETA_DECAY environment steps, then stays at BETA_END.
START_BETA, BETA_END, BETA_DECAY = 1, 1e-4, 10_000

In [4]:
# Environment, actor-critic network, optimizer, critic loss and TensorBoard writer.
env = gym.make('CartPole-v1')
n_obserations, n_actions = env.observation_space.shape[0], env.action_space.n

net = Actor2Critic(n_obserations, n_actions)
opt = optim.Adam(params=net.parameters(), lr=LR)
critic_criterion = nn.MSELoss()  # regression loss between V(s) and the TD target
writer = SummaryWriter()  # default log directory

In [6]:
# Online A2C training loop.
# Each environment step adds one transition's loss to `total_loss`; every
# BATCH_SIZE global steps the accumulated loss is averaged, back-propagated and
# applied. The entropy weight `beta` decays linearly from START_BETA to BETA_END
# over BETA_DECAY steps.
episodes = 10000
step = 0
beta = START_BETA

# Initialised once, not per episode: transitions from the end of an episode
# must still count toward the next update instead of being discarded, so every
# update contains exactly BATCH_SIZE transitions.
total_loss = torch.tensor(0.0)

for episode in range(episodes):
    state, _ = env.reset()
    state = torch.tensor(state)

    R = 0  # undiscounted episode return, used only for logging

    while True:
        beta = BETA_END + (START_BETA - BETA_END) * max((1 - step / BETA_DECAY), 0)
        step += 1

        logits, value = net(state)
        categorical = Categorical(logits=logits)

        action = categorical.sample()
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        next_state = torch.tensor(next_state)
        R += reward

        # One-step TD target. Only a true terminal state has zero value; a
        # time-limit truncation is not terminal, so it still bootstraps from
        # V(s'). The target is a constant: no gradient flows through V(s').
        if terminated:
            next_value = torch.zeros_like(value)
        else:
            with torch.no_grad():
                next_value = net(next_state)[1]
        td_target = reward + GAMMA * next_value

        advantages = (td_target - value).detach()
        # log_prob works from the logits, which is more stable than log(probs[a]).
        policy_loss = -(advantages * categorical.log_prob(action)).mean()
        value_loss = critic_criterion(value, td_target)
        entropy_loss = -beta * categorical.entropy()

        # Entropy bonus is added exactly once per transition.
        total_loss = total_loss + policy_loss + value_loss + entropy_loss

        if step % BATCH_SIZE == 0:
            opt.zero_grad()

            (total_loss / BATCH_SIZE).backward()

            write_grad_info(net.parameters(), writer, 'total', step, is_grad=False)
            # Logged value is the summed (not averaged) loss of the batch.
            writer.add_scalar('total/loss', total_loss.item(), step)

            opt.step()

            total_loss = torch.tensor(0.0)

        if terminated or truncated:
            break

        state = next_state
    writer.add_scalar('reward', R, episode)
    print(f'episode: {episode}, R: {R}, beta: {beta}')

episode: 0, R: 11.0, beta: 0.9990001
episode: 1, R: 14.0, beta: 0.99760024
episode: 2, R: 17.0, beta: 0.99590041
episode: 3, R: 10.0, beta: 0.99490051
episode: 4, R: 12.0, beta: 0.99370063
episode: 5, R: 28.0, beta: 0.99090091
episode: 6, R: 33.0, beta: 0.98760124
episode: 7, R: 18.0, beta: 0.98580142
episode: 8, R: 23.0, beta: 0.98350165
episode: 9, R: 16.0, beta: 0.98190181
episode: 10, R: 29.0, beta: 0.9790021
episode: 11, R: 14.0, beta: 0.9776022400000001
episode: 12, R: 10.0, beta: 0.97660234
episode: 13, R: 16.0, beta: 0.9750025
episode: 14, R: 9.0, beta: 0.97410259
episode: 15, R: 16.0, beta: 0.9725027500000001
episode: 16, R: 25.0, beta: 0.970003
episode: 17, R: 12.0, beta: 0.96880312
episode: 18, R: 19.0, beta: 0.96690331
episode: 19, R: 17.0, beta: 0.96520348
episode: 20, R: 12.0, beta: 0.9640036
episode: 21, R: 19.0, beta: 0.9621037899999999
episode: 22, R: 15.0, beta: 0.96060394
episode: 23, R: 12.0, beta: 0.95940406
episode: 24, R: 12.0, beta: 0.9582041800000001
episode: 2

KeyboardInterrupt: 

In [None]:
# Persist only the trained weights (the state_dict), not the module object.
torch.save(obj=net.state_dict(), f='./a2c.pth')