A C++ implementation of reinforcement learning algorithms using libtorch.
🚧 This repository is currently under development 🚧
Emulates the Snake game with an NN-friendly observation space. Defining aspects of the Snake game:
- Sparse rewards: the chance of randomly encountering an apple is $\frac{1}{H \cdot W}$
- Procedurally generated: many possible states, which avoids overfitting by memorizing a single trajectory
- Difficulty scaling: the complexity of the solution ramps up as the agent progresses and the snake gets longer
- Very fast to update
- Easy to implement
Fully observable, with each position containing [x, y, c], where x and y correspond to width/height and c encodes the cell contents (see the encoding sketch after the lists below):
- $c_0$: snake body
- $c_1$: snake head
- $c_2$: food
- $c_3$: wall
- REWARD_COLLISION = -1
- REWARD_APPLE = 1
- REWARD_MOVE = -0.01
- LEFT = 0
- STRAIGHT = 1
- RIGHT = 2
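To make the observation encoding concrete, here is a minimal sketch of packing the grid into a 4-channel tensor with libtorch. The channel ordering and the border-wall assumption follow the list above, but the function and names are illustrative rather than the repository's actual implementation:

```cpp
#include <torch/torch.h>
#include <utility>
#include <vector>

// Channel layout assumed from the list above: 0 = body, 1 = head, 2 = food, 3 = wall.
enum Channel { kBody = 0, kHead = 1, kFood = 2, kWall = 3 };

// Build a [4, H, W] observation tensor from grid coordinates (illustrative only).
torch::Tensor encode_observation(int64_t height, int64_t width,
                                 const std::vector<std::pair<int64_t, int64_t>>& body,
                                 std::pair<int64_t, int64_t> head,
                                 std::pair<int64_t, int64_t> food) {
    torch::Tensor obs = torch::zeros({4, height, width});
    for (const auto& [y, x] : body) obs[kBody][y][x] = 1.0;
    obs[kHead][head.first][head.second] = 1.0;
    obs[kFood][food.first][food.second] = 1.0;
    // Assumption: the walls are the border cells of the grid.
    obs[kWall].slice(0, 0, 1).fill_(1.0);
    obs[kWall].slice(0, height - 1, height).fill_(1.0);
    obs[kWall].slice(1, 0, 1).fill_(1.0);
    obs[kWall].slice(1, width - 1, width).fill_(1.0);
    return obs;
}
```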
1. Vanilla Policy Gradient[^1]
This method is the basis for a large family of modern RL algorithms which operate by directly maximizing the expected reward or value of the actions taken by a policy.
Exploration is encouraged with an epsilon-greedy schedule that is annealed linearly across episodes, where

- $n$ is the current episode index.
- $N$ is the total number of episodes.

The decay terminates when $n = N$.
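As a rough illustration of the vanilla policy gradient update (not the repository's exact code), the loss can be computed from the log-probabilities of the sampled actions and their discounted returns:

```cpp
#include <torch/torch.h>
#include <vector>

// Minimal REINFORCE-style loss: maximize E[log pi(a|s) * G_t] by minimizing its negative.
// `log_probs` are log pi(a_t|s_t) gathered for the actions actually taken,
// `returns` are the discounted rewards-to-go G_t for the same timesteps.
torch::Tensor policy_gradient_loss(const torch::Tensor& log_probs,
                                   const torch::Tensor& returns) {
    // Returns are treated as constants (no gradient flows through them).
    return -(log_probs * returns.detach()).mean();
}

// Discounted rewards-to-go for a single episode: G_t = r_t + gamma * G_{t+1}.
std::vector<double> rewards_to_go(const std::vector<double>& rewards, double gamma) {
    std::vector<double> returns(rewards.size());
    double running = 0.0;
    for (int t = static_cast<int>(rewards.size()) - 1; t >= 0; --t) {
        running = rewards[t] + gamma * running;
        returns[t] = running;
    }
    return returns;
}
```

Minimizing this loss increases the probability of the actions that led to high returns.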
2. Policy Gradient with entropy regularization[^2]
Because models often converge on shallow maxima, their subsequent rollouts can become limited in exploration. Entropy regularization adds a reward for random behavior, which is intended to dampen the feedback loop between action sampling (which generates the training data) and bias in the policy.
In practice, the entropy term is balanced against the reward term so that it prevents early collapse while still allowing high-reward rollouts to dwarf the entropy bonus. In this sense, the entropy regularizer is a convenient and adaptive method for preventing early convergence, compared to epsilon scheduling, which is typically hardcoded.
The combined objective is

$$L = L_{PG} - \lambda \, H\left(\pi(\cdot \mid s)\right)$$

where

$$H\left(\pi(\cdot \mid s)\right) = -\sum_{a} \pi(a \mid s) \log \pi(a \mid s)$$

and $\lambda$ is the entropy coefficient. Entropy is maximized when the action distribution emitted by the policy is uniform.
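A minimal sketch of the entropy-regularized loss, assuming the policy head emits log-softmax probabilities (as in the models below); the function name and signature are illustrative:

```cpp
#include <torch/torch.h>

// Entropy-regularized policy gradient loss (a sketch; not the repository's exact code).
// `log_probs_all` is [T, n_actions] log-softmax output of the policy for each state,
// `taken_log_probs` is [T] log pi(a_t|s_t) for the actions actually taken,
// `returns` is [T] discounted returns, `lambda` is the entropy coefficient.
torch::Tensor pg_entropy_loss(const torch::Tensor& log_probs_all,
                              const torch::Tensor& taken_log_probs,
                              const torch::Tensor& returns,
                              double lambda) {
    torch::Tensor pg_loss = -(taken_log_probs * returns.detach()).mean();
    // H(pi) = -sum_a pi(a|s) log pi(a|s), averaged over the batch of states.
    torch::Tensor entropy = -(log_probs_all.exp() * log_probs_all).sum(/*dim=*/1).mean();
    // Subtracting the entropy term rewards more uniform (exploratory) action distributions.
    return pg_loss - lambda * entropy;
}
```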
3. Advantage Actor-Critic (A2C)
This method employs a second network that approximates the value function for a given state. During training, the actor attempts to maximize the advantage[^4], which captures how much better a particular action is than the expected value predicted by the critic. Subtracting this baseline reduces the variance of the policy gradient and provides more stable updates, even when the absolute scale of the rewards varies widely. While the critic loss remains sensitive to raw reward magnitude, the advantage-based policy (actor) update is relatively well-behaved.
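A simplified sketch of the actor and critic losses with baseline subtraction; the structure and names are assumptions, not the repository's A2CAgent internals:

```cpp
#include <torch/torch.h>

// Actor-critic losses with an advantage baseline (illustrative sketch).
// `taken_log_probs` = log pi(a_t|s_t) for the actions taken,
// `values` = critic estimates V(s_t), `returns` = return targets for the same states.
struct ActorCriticLoss {
    torch::Tensor actor;
    torch::Tensor critic;
};

ActorCriticLoss a2c_losses(const torch::Tensor& taken_log_probs,
                           const torch::Tensor& values,
                           const torch::Tensor& returns) {
    // Advantage: how much better the outcome was than the critic's estimate.
    torch::Tensor advantage = returns - values;
    // The advantage is detached so the actor loss does not push gradients into the critic.
    torch::Tensor actor_loss = -(taken_log_probs * advantage.detach()).mean();
    // The critic regresses V(s_t) toward the return targets.
    torch::Tensor critic_loss = torch::mse_loss(values, returns.detach());
    return {actor_loss, critic_loss};
}
```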
4. Asynchronous Advantage Actor-Critic (A3C)[^3]
This implementation of A3C makes use of a specialized, thread-safe parameter optimizer, RMSPropAsync, which combines gradients from worker threads to update a shared parameter set. The shared parameter set is then distributed back to the workers. It is not lock-free, as the optimizer in the original A3C paper is claimed to be, but it offers a reasonably low-contention alternative in which each module in the neural network has its own mutex. For simplicity and modularity, the A3CAgent class initializes a thread pool of A2CAgents, each of which is given a synchronization lambda.
RMSPropAsync takes the same form as standard non-centered RMSProp, with some changes to make it compatible with multiple workers. The gradients computed by the workers are returned to RMSPropAsync, which then updates its internal state, tracking the exponential moving average (EMA) of the squared gradient.
The update applied for a gradient contribution from worker $w$ is

$$G \leftarrow \alpha G + (1 - \alpha)\, g_w^2$$

$$\theta \leftarrow \theta - \frac{\eta \, g_w}{\sqrt{G} + \epsilon}$$

where

- $\theta$: the shared model parameter being updated.
- $\eta$: the learning rate.
- $g_w$: the gradient of the loss with respect to $\theta$ from worker $w$.
- $G$: the exponentially weighted moving average of squared gradients (the RMSProp accumulator).
- $\alpha$: the decay factor for the moving average.
- $\epsilon$: a small constant added for numerical stability.
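The per-module locking described above might look roughly like the following sketch of a single guarded update; RMSPropAsync's real interface and chunking strategy will differ:

```cpp
#include <torch/torch.h>
#include <mutex>

// Sketch of a shared RMSProp-style update guarded by a mutex (illustrative only).
// Each shared parameter (or module) carries its own mutex so workers contend
// only on the slice of the model they are currently updating.
struct SharedParam {
    torch::Tensor theta;  // shared model parameter
    torch::Tensor G;      // EMA of squared gradients
    std::mutex mtx;       // per-parameter lock
};

void rmsprop_async_step(SharedParam& p, const torch::Tensor& grad,
                        double lr, double alpha, double eps) {
    std::lock_guard<std::mutex> lock(p.mtx);
    torch::NoGradGuard no_grad;
    // G <- alpha * G + (1 - alpha) * g^2
    p.G.mul_(alpha).addcmul_(grad, grad, 1.0 - alpha);
    // theta <- theta - lr * g / (sqrt(G) + eps)
    p.theta.addcdiv_(grad, p.G.sqrt().add(eps), -lr);
}
```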
The scaling behavior of RMSPropAsync is as follows, where "naive" refers to simple linear iteration over the parameter vector, and the other variants use finer-grained, uniform chunk sizes with a randomized iteration order.
5. Proximal Policy Optimization (PPO)[^5]
PPO's approach to increasing sample efficiency is to sample trajectories in large chunks and then perform batched training on those trajectories. This is convenient for deployment because the trajectory sampling can be performed entirely without gradient tracking (inference only), and the collected trajectories can then be kept local for high-efficiency training.
The focus of their publication is a loss term that acts as a surrogate for a KL-divergence penalty, referred to as the clipped surrogate objective:

$$L^{CLIP}(\theta) = -\mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]$$

where

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

is the probability ratio between the current policy and the policy that sampled the trajectories, and $\hat{A}_t$ is the advantage estimate at timestep $t$. Taken by itself, the unclipped term $r_t(\theta)\,\hat{A}_t$ would permit arbitrarily large policy updates from a single batch of trajectories. The clipped term removes the incentive to push $r_t(\theta)$ outside the interval $[1-\epsilon,\, 1+\epsilon]$, keeping each update close to the sampling policy.
(Signs are flipped relative to the publication to make it clear that the L stands for a loss, which should be minimized.)
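A sketch of the clipped surrogate loss in libtorch, with the sign already flipped so it is minimized; names and signature are illustrative:

```cpp
#include <torch/torch.h>

// Clipped surrogate loss from PPO (sketch). `new_log_probs` come from the policy
// being optimized, `old_log_probs` from the (frozen) policy that sampled the
// trajectories, `advantages` are precomputed advantage estimates, `clip_eps` is epsilon.
torch::Tensor ppo_clip_loss(const torch::Tensor& new_log_probs,
                            const torch::Tensor& old_log_probs,
                            const torch::Tensor& advantages,
                            double clip_eps) {
    // Probability ratio r_t(theta) = pi_theta(a|s) / pi_theta_old(a|s).
    torch::Tensor ratio = (new_log_probs - old_log_probs.detach()).exp();
    torch::Tensor unclipped = ratio * advantages;
    torch::Tensor clipped =
        torch::clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages;
    // Sign is flipped relative to the paper so this is a loss to minimize.
    return -torch::min(unclipped, clipped).mean();
}
```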
Fully connected MLP with LayerNorm and GELU activations (a libtorch sketch follows the table):

| Layer | Dimensions |
|---|---|
| input | w*h*4 |
| fc | 256 |
| layernorm | - |
| GELU | - |
| fc | 256 |
| layernorm | - |
| GELU | - |
| fc | 256 |
| layernorm | - |
| GELU | - |
| fc | output_size |
| log_softmax | output_size |
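A sketch of how this MLP could be expressed with the libtorch C++ frontend; the module names (and the class name itself) are assumptions, not necessarily those used in the repository:

```cpp
#include <torch/torch.h>

// MLP matching the table above (sketch): three 256-wide hidden layers with
// LayerNorm + GELU, followed by a linear head and log-softmax over actions.
struct SimpleMLPImpl : torch::nn::Module {
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}, head{nullptr};
    torch::nn::LayerNorm ln1{nullptr}, ln2{nullptr}, ln3{nullptr};

    SimpleMLPImpl(int64_t input_size, int64_t output_size) {
        fc1 = register_module("fc1", torch::nn::Linear(input_size, 256));
        ln1 = register_module("ln1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({256})));
        fc2 = register_module("fc2", torch::nn::Linear(256, 256));
        ln2 = register_module("ln2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({256})));
        fc3 = register_module("fc3", torch::nn::Linear(256, 256));
        ln3 = register_module("ln3", torch::nn::LayerNorm(torch::nn::LayerNormOptions({256})));
        head = register_module("head", torch::nn::Linear(256, output_size));
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::gelu(ln1(fc1(x)));
        x = torch::gelu(ln2(fc2(x)));
        x = torch::gelu(ln3(fc3(x)));
        return torch::log_softmax(head(x), /*dim=*/-1);
    }
};
TORCH_MODULE(SimpleMLP);
```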
DenseNet-style model with 2 convolution layers and CBAM spatial/channel attention[^6] [^7] (a libtorch sketch follows the table):
| Layer | Dimensions |
|---|---|
| Input | input_width * input_height * input_channels |
| Conv2D (conv1) | 8 filters, kernel=3x3, stride=1, padding=1 |
| GELU | - |
| Concat | input + conv1 output |
| Conv2D (conv2) | 16 filters, kernel=3x3, stride=1, padding=1 |
| GELU | - |
| Concat | input + conv1 output + conv2 output |
| Channel Attention | input + conv1 output + conv2 output |
| Spatial Attention | input + conv1 output + conv2 output |
| Residual Add | input |
| Flatten | input_width * input_height * (input_channels + 8 + 16) |
| Fully Connected (fc1) | 256 |
| LayerNorm (ln1) | 256 |
| GELU | - |
| Fully Connected (fc2) | 128 |
| LayerNorm (ln2) | 128 |
| GELU | - |
| Fully Connected (fc3) | output_size |
| Log Softmax (if multiple outputs) | output_size |
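A condensed sketch of the dense convolution block with CBAM-style channel and spatial attention, up to the flatten step (the FC head would mirror the MLP sketch above). Layer sizes follow the table, but the attention details and names are assumptions:

```cpp
#include <torch/torch.h>
#include <tuple>

// Dense block with CBAM-style attention (illustrative sketch of the table above).
struct DenseCBAMBlockImpl : torch::nn::Module {
    torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, spatial_conv{nullptr};
    torch::nn::Linear ca_fc1{nullptr}, ca_fc2{nullptr};

    explicit DenseCBAMBlockImpl(int64_t in_channels) {
        int64_t cat_channels = in_channels + 8 + 16;  // input + conv1 + conv2 outputs
        conv1 = register_module("conv1", torch::nn::Conv2d(
            torch::nn::Conv2dOptions(in_channels, 8, 3).stride(1).padding(1)));
        conv2 = register_module("conv2", torch::nn::Conv2d(
            torch::nn::Conv2dOptions(in_channels + 8, 16, 3).stride(1).padding(1)));
        // Channel attention: shared MLP over pooled channel descriptors.
        ca_fc1 = register_module("ca_fc1", torch::nn::Linear(cat_channels, cat_channels / 2));
        ca_fc2 = register_module("ca_fc2", torch::nn::Linear(cat_channels / 2, cat_channels));
        // Spatial attention: conv over the [avg, max] channel-wise maps.
        spatial_conv = register_module("spatial_conv", torch::nn::Conv2d(
            torch::nn::Conv2dOptions(2, 1, 7).padding(3)));
    }

    torch::Tensor forward(torch::Tensor x) {
        auto f1 = torch::gelu(conv1(x));
        auto d1 = torch::cat({x, f1}, /*dim=*/1);
        auto f2 = torch::gelu(conv2(d1));
        auto d2 = torch::cat({d1, f2}, /*dim=*/1);  // input + conv1 + conv2 outputs

        // Channel attention (CBAM): sigmoid(MLP(avgpool) + MLP(maxpool)).
        auto avg = torch::adaptive_avg_pool2d(d2, {1, 1}).flatten(1);
        auto mx  = std::get<0>(torch::adaptive_max_pool2d(d2, {1, 1})).flatten(1);
        auto ca  = torch::sigmoid(ca_fc2(torch::relu(ca_fc1(avg))) +
                                  ca_fc2(torch::relu(ca_fc1(mx))));
        auto att = d2 * ca.unsqueeze(-1).unsqueeze(-1);

        // Spatial attention (CBAM): conv over channel-wise mean and max maps.
        auto sp_avg = att.mean(/*dim=*/1, /*keepdim=*/true);
        auto sp_max = std::get<0>(att.max(/*dim=*/1, /*keepdim=*/true));
        auto sa = torch::sigmoid(spatial_conv(torch::cat({sp_avg, sp_max}, 1)));
        att = att * sa;

        // Residual add of the input onto its corresponding channels, then flatten.
        auto res  = att.narrow(1, 0, x.size(1)) + x;
        auto rest = att.narrow(1, x.size(1), att.size(1) - x.size(1));
        return torch::cat({res, rest}, 1).flatten(1);
    }
};
TORCH_MODULE(DenseCBAMBlock);
```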
🍒 These demos are chosen arbitrarily by me and are not a robust indication of the relative performance of these methods
An example of a mildly successful Policy Gradient SimpleConv agent trained with entropy regularization. You can see that it has converged on a circling behavior for self-avoidance, and it randomly biases its circular motion toward the apple. This agent was trained with the deprecated 4-directional absolute action space as opposed to the 3-directional relative one.
An example of a slightly more successful A3C SimpleConv agent trained with entropy regularization. It more directly targets the apples, sometimes to its own detriment. It has a strong left turn bias. Trained with:
`./train_snake --type a3c --gamma 0.9 --learn_rate 0.0001 --lambda 0.07 --n_episodes 60000 --n_threads 24`
The default episode length is 16 steps. Environments from episodes that were neither truncated nor terminated are carried over into the next episode.
Apologies for the low GIF quality; this test run includes added noise.
From my limited trials, it seems that PPO has a higher upper limit on reward in this environment, whereas A3C didn't see much benefit from training beyond 1M steps. More benchmarking to come. This run was trained with linear learn-rate annealing, Adam-Rel[^8], and 3M steps.
`./train_snake --type ppo --n_threads 24 --batch_size 128 --lambda 0.02 --learn_rate 0.00025 --gamma 0.95 --n_steps_per_cycle 4096 --n_steps 3000000 --learn_rate_final 0`
Here are 9 runs, all more than 1.0 standard deviation above the average performance for this model:
I chose C++ because I think it is well suited for building multithreaded applications, and because I want it to interface directly with high-performance methods and algorithms that are also written in C++. Part of the appeal of RL (to me) is that it can be applied to many different types of control and optimization problems. For training/evaluation/reward purposes, it is useful to be able to perform CPU-bound or sequential operations as fast as possible. There are many benchmark environments available with Python interfaces, but my eventual goal is to apply this to my own custom environments. In addition to these considerations, PyTorch's RL support and documentation are fairly limited and loosely structured, so I found them difficult to use as a starting point.
TODO:

- Benchmark speed vs n_threads for A3C
- Add model checkpoints/saving/loading
- Print critic's value estimation for every state during test demo
- (Truly) exhaustive tests and fix A3C regression
- Sample failure modes of Snake and consider frame stacking, different encoding
- Plot attention map
- Implement A3C (now currently A2C)
- Critic network and baseline subtraction
- Visualization:
  - Basic training loss plot (split into reward and entropy terms)
  - Trained model behavior
  - Save as GIF/video (automatically)
  - Action distributions per state
- More appropriate models for encoding observation space:
  - CNN (priority)
  - RNN
  - GNN <3
- DQN
  - Likely important for SnakeEnv, which is essentially Cliff World
- Abstract away specific NN classes
- Exhaustive comparison of methods
- Break out epsilon annealing into simple class (now deprioritized by entropy loss)
[^1]: Sutton, R. S., McAllester, D., Singh, S., & Mansour, Y. (1999). Policy gradient methods for reinforcement learning with function approximation. In Advances in Neural Information Processing Systems (Vol. 12). MIT Press.
[^2]: Williams, R. J., & Peng, J. (1991). Function optimization using connectionist reinforcement learning algorithms. Connection Science, 3(3), 241–268. https://doi.org/10.1080/09540099108946587
[^3]: Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T., Harley, T., Silver, D., & Kavukcuoglu, K. (2016). Asynchronous methods for deep reinforcement learning. arXiv preprint arXiv:1602.01783. https://doi.org/10.48550/arXiv.1602.01783
[^4]: Schulman, J., Moritz, P., Levine, S., Jordan, M., & Abbeel, P. (2018). High-dimensional continuous control using generalized advantage estimation. arXiv preprint arXiv:1506.02438. https://doi.org/10.48550/arXiv.1506.02438
[^5]: Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347. https://doi.org/10.48550/arXiv.1707.06347
[^6]: Woo, S., Park, J., Lee, J.-Y., & Kweon, I. S. (2018). CBAM: Convolutional block attention module. arXiv preprint arXiv:1807.06521. https://doi.org/10.48550/arXiv.1807.06521
[^7]: Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2018). Densely connected convolutional networks. arXiv preprint arXiv:1608.06993. https://doi.org/10.48550/arXiv.1608.06993
[^8]: Ellis, B., Jackson, M. T., Lupu, A., Goldie, A. D., Fellows, M., Whiteson, S., & Foerster, J. (2024). Adam on local time: Addressing nonstationarity in RL with relative Adam timesteps. arXiv preprint arXiv:2412.17113. https://doi.org/10.48550/arXiv.2412.17113



