Developers only need three steps to implement an RL algorithm with **RLLTE**:

1. Select an algorithm prototype;
2. Select the desired modules;
3. Write an update function.

The following example illustrates how to write an Advantage Actor-Critic (A2C) agent to solve Atari games.

# 1. Set prototype

First, we select `OnPolicyAgent` as the prototype:

In [1]:
!pip install rllte-core

In [2]:
from rllte.common.prototype import OnPolicyAgent

class A2C(OnPolicyAgent):
    def __init__(self, env, tag, device, num_steps):
        # here we only use four arguments
        super().__init__(env=env, tag=tag, device=device, num_steps=num_steps)



# 2. Set necessary modules

Now we need an `encoder` to process observations, a learnable `policy` to generate actions, a `storage` to store and sample experiences, and a `distribution` from which actions are sampled.

In [3]:
from rllte.xploit.encoder import MnihCnnEncoder
from rllte.xploit.policy import OnPolicySharedActorCritic
from rllte.xploit.storage import VanillaRolloutStorage
from rllte.xplore.distribution import Categorical
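
To make the encoder's role concrete, here is a minimal PyTorch sketch of a Mnih-style convolutional encoder that maps a stack of 84x84 frames to a 512-dimensional feature vector. The class name `NatureCnnSketch`, the layer sizes, and the input shape are assumptions based on the commonly used DQN architecture; they are not taken from the source of RLLTE's `MnihCnnEncoder`, which may differ in detail.

```python
import torch as th
from torch import nn


class NatureCnnSketch(nn.Module):
    """Hypothetical stand-in for a Mnih-style CNN encoder (not RLLTE's code)."""

    def __init__(self, in_channels: int = 4, feature_dim: int = 512) -> None:
        super().__init__()
        # Three convolutions followed by a flatten, as in the DQN architecture.
        self.trunk = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened size from a dummy 84x84 frame stack.
        with th.no_grad():
            n_flatten = self.trunk(th.zeros(1, in_channels, 84, 84)).shape[1]
        self.head = nn.Sequential(nn.Linear(n_flatten, feature_dim), nn.ReLU())

    def forward(self, obs: th.Tensor) -> th.Tensor:
        # Assumes pixel values have already been scaled to [0, 1].
        return self.head(self.trunk(obs))


encoder_sketch = NatureCnnSketch()
features = encoder_sketch(th.rand(2, 4, 84, 84))  # batch of 2 frame stacks
print(features.shape)
```

The `feature_dim` of the encoder must match the `feature_dim` expected by the policy, which is why both are set to 512 in the agent defined later.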

# 3. Set update function

Run the `.describe` function of the selected policy and you will see the following output:

In [4]:
OnPolicySharedActorCritic.describe()



Name       : OnPolicySharedActorCritic
Structure  : self.encoder (shared by actor and critic), self.actor, self.critic
Forward    : obs -> self.encoder -> self.actor -> actions
           : obs -> self.encoder -> self.critic -> values
           : actions -> log_probs
Optimizers : self.optimizers['opt'] -> (self.encoder, self.actor, self.critic)
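
The structure printed above can be mimicked in a few lines of plain PyTorch. The sketch below is illustrative only: `SharedActorCriticSketch`, its linear encoder, and the dimensions are hypothetical and do not reproduce RLLTE's `OnPolicySharedActorCritic`. It shows one encoder feeding both an actor head and a critic head, with a single optimizer stored under the key `'opt'` that covers all three parts.

```python
import torch as th
from torch import nn
from torch.distributions import Categorical as TorchCategorical


class SharedActorCriticSketch(nn.Module):
    """Hypothetical shared actor-critic for a discrete action space."""

    def __init__(self, obs_dim: int, num_actions: int, feature_dim: int = 64) -> None:
        super().__init__()
        # The encoder is shared by the actor and the critic.
        self.encoder = nn.Sequential(nn.Linear(obs_dim, feature_dim), nn.ReLU())
        self.actor = nn.Linear(feature_dim, num_actions)
        self.critic = nn.Linear(feature_dim, 1)
        # One optimizer updates the encoder, actor, and critic together.
        self.optimizers = {"opt": th.optim.Adam(self.parameters(), lr=2.5e-4, eps=1e-5)}

    def forward(self, obs: th.Tensor):
        features = self.encoder(obs)
        # obs -> encoder -> actor -> actions
        dist = TorchCategorical(logits=self.actor(features))
        actions = dist.sample()
        # obs -> encoder -> critic -> values
        values = self.critic(features)
        # actions -> log_probs
        return actions, values, dist.log_prob(actions)


policy_sketch = SharedActorCriticSketch(obs_dim=8, num_actions=4)
actions, values, log_probs = policy_sketch(th.rand(5, 8))
print(actions.shape, values.shape, log_probs.shape)
```

Because the encoder is shared, gradients from both the policy loss and the value loss flow into the same feature extractor, which is why the two losses are combined into a single backward pass in the update function.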




The `.describe` output summarizes the structure of the policy and indicates which parts are optimized. Finally, merge these modules and write an `.update` function:

In [5]:
from torch import nn
import torch as th

class A2C(OnPolicyAgent):
    def __init__(self, env, tag, seed, device, num_steps) -> None:
        super().__init__(env=env, tag=tag, seed=seed, device=device, num_steps=num_steps)
        # create modules
        encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=512)
        policy = OnPolicySharedActorCritic(observation_space=env.observation_space,
                                           action_space=env.action_space,
                                           feature_dim=512,
                                           opt_class=th.optim.Adam,
                                           opt_kwargs=dict(lr=2.5e-4, eps=1e-5),
                                           init_fn="xavier_uniform"
                                           )
        storage = VanillaRolloutStorage(observation_space=env.observation_space,
                                        action_space=env.action_space,
                                        device=device,
                                        storage_size=self.num_steps,
                                        num_envs=self.num_envs,
                                        batch_size=256
                                        )
        dist = Categorical()
        # set all the modules
        self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist)
    
    def update(self):
        for _ in range(4):
            for batch in self.storage.sample():
                # evaluate the sampled actions
                new_values, new_log_probs, entropy = self.policy.evaluate_actions(obs=batch.observations, actions=batch.actions)
                # policy loss part
                policy_loss = - (batch.adv_targ * new_log_probs).mean()
                # value loss part
                value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean()
                # update: the value loss is weighted by 0.5 and an entropy bonus weighted by 0.01 is subtracted
                self.policy.optimizers['opt'].zero_grad(set_to_none=True)
                (value_loss * 0.5 + policy_loss - entropy * 0.01).backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                self.policy.optimizers['opt'].step()
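
The loss arithmetic in `update` can be checked in isolation on dummy tensors. The following sketch uses hypothetical stand-ins for the mini-batch (`observations`, `actions`, `returns`, `adv_targ`) and small linear layers in place of the real policy. It also assumes the entropy term is a scalar mean, which may not match what `evaluate_actions` returns in RLLTE.

```python
import torch as th
from torch import nn
from torch.distributions import Categorical as TorchCategorical

th.manual_seed(0)

# Hypothetical stand-ins for one sampled mini-batch.
batch_size, obs_dim, num_actions = 16, 8, 4
observations = th.rand(batch_size, obs_dim)
actions = th.randint(0, num_actions, (batch_size,))
returns = th.rand(batch_size)
adv_targ = th.randn(batch_size)

# Tiny actor and critic heads in place of the real policy.
actor = nn.Linear(obs_dim, num_actions)
critic = nn.Linear(obs_dim, 1)
params = list(actor.parameters()) + list(critic.parameters())
optimizer = th.optim.Adam(params, lr=2.5e-4, eps=1e-5)

# Evaluate the sampled actions under the current policy.
dist = TorchCategorical(logits=actor(observations))
new_log_probs = dist.log_prob(actions)
entropy = dist.entropy().mean()
new_values = critic(observations)

# Same loss terms and coefficients as in the update function.
policy_loss = -(adv_targ * new_log_probs).mean()
value_loss = 0.5 * (new_values.flatten() - returns).pow(2).mean()
total_loss = value_loss * 0.5 + policy_loss - entropy * 0.01

# One gradient step with the gradient norm clipped to 0.5.
optimizer.zero_grad(set_to_none=True)
total_loss.backward()
nn.utils.clip_grad_norm_(params, 0.5)
optimizer.step()
print(float(policy_loss), float(value_loss), float(entropy))
```

Subtracting the entropy term rewards more uniform action distributions, which discourages the policy from collapsing onto a single action early in training.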

# 4. Start training

Now we can start training:

In [6]:
from rllte.env import make_atari_env
if __name__ == "__main__":
    device = "cuda"
    env = make_atari_env("AlienNoFrameskip-v4", num_envs=8, seed=0, device=device)
    agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
    agent.train(num_train_steps=10000)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


[08/28/2023 05:43:19 PM] - [[1m[34mINFO.[0m] - Invoking RLLTE Engine...
[08/28/2023 05:43:19 PM] - [[1m[34mINFO.[0m] - Tag               : a2c_atari
[08/28/2023 05:43:19 PM] - [[1m[34mINFO.[0m] - Device            : NVIDIA GeForce RTX 3090
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Agent             : A2C
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Encoder           : MnihCnnEncoder
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Policy            : OnPolicySharedActorCritic
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Storage           : VanillaRolloutStorage
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Distribution      : Categorical
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Augmentation      : False
[08/28/2023 05:43:19 PM] - [[1m[33mDEBUG[0m] - Intrinsic Reward  : False
[08/28/2023 05:43:22 PM] - [[1m[31mTRAIN[0m] - S: 1024        | E: 8           | L: 128         | R: 50.000      | FPS: 241.297   | T: 0:00:04    
[08/28/2023 05:43:23
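
The settings used in this run determine how much data each update sees. The short calculation below is a back-of-the-envelope sketch, not RLLTE code: it assumes that each rollout collects `num_envs * num_steps` transitions, that `storage.sample()` yields mini-batches of `batch_size` transitions, and that training proceeds in whole rollouts. The variable names are illustrative.

```python
# Values taken from the example above.
num_envs = 8            # parallel environments
num_steps = 128         # steps collected per environment per rollout
batch_size = 256        # mini-batch size passed to the storage
num_epochs = 4          # passes over the data in update()
num_train_steps = 10000

# Transitions gathered before each call to update().
transitions_per_rollout = num_envs * num_steps                 # 1024
# Mini-batches per pass, assuming the storage is split evenly.
minibatches_per_pass = transitions_per_rollout // batch_size   # 4
# Optimizer steps performed by one call to update().
gradient_steps_per_update = num_epochs * minibatches_per_pass  # 16
# Whole rollouts that fit in the training budget (an assumption).
rollouts = num_train_steps // transitions_per_rollout          # 9

print(transitions_per_rollout, minibatches_per_pass,
      gradient_steps_per_update, rollouts)
```

The 1024 transitions per rollout appear consistent with the `S: 1024` value reported in the first `TRAIN` log line.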

As shown in this example, only a few dozen lines of code are needed to create RL agents with **RLLTE**. 