In [155]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Categorical
import gymnasium as gym

In [161]:
class Actor2Critic(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.backbone = nn.Sequential(
            nn.Linear(in_channels, 128),
            nn.ReLU(),
            #
            nn.Linear(128, 64),
            nn.ReLU(),
            #
            nn.Linear(64, 32),
            nn.ReLU(),
        )

        self.actor_head = nn.Sequential(
            nn.Linear(32, out_channels),
        )

        self.critic_head = nn.Sequential(
            nn.Linear(32, 1),
        )

    def forward(self, x):
        y = self.backbone(x)

        actor = self.actor_head(y)
        critic = self.critic_head(y)

        return actor, critic

In [172]:
def compute_returns(next_value, rewards, masks, gamma):
    R = next_value
    returns = []
    for i in reversed(range(rewards)):
        R = rewards[i] + gamma * R * masks[i]
        returns.insert(0, R)

    return returns

In [173]:
# Params
LR = 0.01
BATCH_SIZE = 32
GAMMA = 0.99

In [174]:
env = gym.make('CartPole-v1')
n_observations = env.observation_space.shape[0]
n_actions = env.action_space.n

net = Actor2Critic(n_observations, n_actions)
opt = torch.optim.Adam(net.parameters(), lr=LR)
mse_criterion = nn.MSELoss()
writer = SummaryWriter('./runs/little_network')

In [None]:
epochs = 700

state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float32).view(1, -1)  # 1xn_obs
r_max = 0

for epoch in range(epochs):
    total_loss = torch.tensor([0.0])
    R = 0

    for i in range(BATCH_SIZE):
        action_logits, value = net(state)  # 1xn_acts, 1x1
        dist = Categorical(logits=action_logits)

        # apply action
        action = dist.sample()  # [1]
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        R += reward

        next_state = torch.tensor(next_state, dtype=torch.float32).view(1, -1)  # 1xn_obs
        reward = torch.tensor(reward, dtype=torch.float32).view(1, -1)  # 1x1

        _, next_value = net(next_state)

        # calculate losses
        td_target = reward + GAMMA * next_value  # r + gamma * V(s')
        advantage = td_target - value  # r + gamma * V(s') - V(s)
        log_prob = dist.log_prob(action)

        value_loss = mse_criterion(value, td_target).view(1)  #
        actor_loss = -(advantage * log_prob).view(1)  # A(s, a) log(pi(a | s))

        total_loss += value_loss + actor_loss

        # prepare next state
        state = next_state

        if terminated or truncated:
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32).view(1, -1)  # 1xn_obs
            writer.add_scalar('total/reward', R, epoch)
            if R > r_max:
                r_max = R

            R = 0
            break

    opt.zero_grad()
    (total_loss / BATCH_SIZE).backward()

    # statistical analysis
    grads = torch.cat([p.grad.flatten() for p in net.parameters() if p.grad is not None])
    max_grad = grads.max()
    var_grad = grads.var()
    norm_grad = grads.norm()

    writer.add_scalar('grads/max', max_grad, epoch)
    writer.add_scalar('grads/var', var_grad, epoch)
    writer.add_scalar('grads/norm', norm_grad, epoch)

    opt.step()

    print(f'Episode {epoch} loss: {total_loss.item()}, r_max: {r_max}')

Episode 0 loss: 13.541119575500488, r_max: 0
Episode 1 loss: 11.712241172790527, r_max: 7.0
Episode 2 loss: 12.727513313293457, r_max: 7.0
Episode 3 loss: 10.221213340759277, r_max: 7.0
Episode 4 loss: 12.619786262512207, r_max: 7.0
Episode 5 loss: 14.82143783569336, r_max: 7.0
Episode 6 loss: 14.097794532775879, r_max: 7.0
Episode 7 loss: 6.197763919830322, r_max: 7.0
Episode 8 loss: 14.483087539672852, r_max: 7.0
Episode 9 loss: 10.11109733581543, r_max: 7.0
Episode 10 loss: 12.749632835388184, r_max: 8.0
Episode 11 loss: 11.435988426208496, r_max: 8.0
Episode 12 loss: 2.957047939300537, r_max: 8.0
Episode 13 loss: 12.187858581542969, r_max: 8.0
Episode 14 loss: 3.6430439949035645, r_max: 8.0
Episode 15 loss: 5.176426410675049, r_max: 8.0
Episode 16 loss: 0.8184529542922974, r_max: 8.0
Episode 17 loss: 2.2613987922668457, r_max: 8.0
Episode 18 loss: 0.06633568555116653, r_max: 8.0
Episode 19 loss: 0.0651925727725029, r_max: 8.0
Episode 20 loss: 0.362490713596344, r_max: 8.0
Episode 2

In [168]:
torch.save(net.state_dict(), './a2c.pth')