# Deep Reinforcement Learning

In [1]:
#include <chrono>
#include <iostream>

#include <backprop_tools/operations/cpu.h>
#include <backprop_tools/nn/operations_cpu.h>
#include <backprop_tools/rl/environments/operations_generic.h>
#include <backprop_tools/nn_models/operations_cpu.h>
#include <backprop_tools/rl/operations_generic.h>
namespace bpt = backprop_tools;

In [2]:
#include <backprop_tools/rl/utils/evaluation.h>

In [3]:
using T = float;
using DEVICE = bpt::devices::DefaultCPU;
using TI = typename DEVICE::index_t;

In [4]:
// Pendulum environment with the library's default parameters, built from the scalar (T)
// and index (TI) types declared above.
using ENVIRONMENT_PARAMETERS = bpt::rl::environments::pendulum::DefaultParameters<T>;
using ENVIRONMENT_SPEC = bpt::rl::environments::pendulum::Specification<
    T,
    TI,
    ENVIRONMENT_PARAMETERS
>;
using ENVIRONMENT = bpt::rl::environments::Pendulum<ENVIRONMENT_SPEC>;

In [5]:
// TD3 hyperparameters for the pendulum task: the library defaults, with the actor and
// critic mini-batch sizes overridden to 100.
struct TD3_PENDULUM_PARAMETERS: bpt::rl::algorithms::td3::DefaultParameters<T, TI>{
    constexpr static TI CRITIC_BATCH_SIZE = 100;
    constexpr static TI ACTOR_BATCH_SIZE = 100;
};
// Total number of environment steps the training loop runs for.
constexpr TI STEP_LIMIT = 10000;
// Replay buffer capacity equals STEP_LIMIT; with a single environment the buffer is
// presumably never overwritten during the run.
constexpr TI REPLAY_BUFFER_CAP = STEP_LIMIT;
// Training only starts after this many steps, so that (presumably) at least one full actor
// batch of transitions has been collected. Declared as TI like the other step counts.
constexpr TI N_WARMUP_STEPS = TD3_PENDULUM_PARAMETERS::ACTOR_BATCH_SIZE;
// Maximum number of steps per episode (passed to the off-policy runner and to evaluation).
constexpr TI EPISODE_STEP_LIMIT = 200;
// Actor and critic are both MLPs with 3 layers of width 64 and ReLU hidden activations.
constexpr TI ACTOR_NUM_LAYERS = 3;
constexpr TI ACTOR_HIDDEN_DIM = 64;
constexpr TI CRITIC_NUM_LAYERS = 3;
constexpr TI CRITIC_HIDDEN_DIM = 64;
constexpr auto ACTOR_ACTIVATION_FUNCTION = bpt::nn::activation_functions::RELU;
constexpr auto CRITIC_ACTIVATION_FUNCTION = bpt::nn::activation_functions::RELU;
// tanh bounds the actor output to (-1, 1); the critic's scalar output is left unbounded.
constexpr auto ACTOR_ACTIVATION_FUNCTION_OUTPUT = bpt::nn::activation_functions::TANH;
constexpr auto CRITIC_ACTIVATION_FUNCTION_OUTPUT = bpt::nn::activation_functions::IDENTITY;
using TD3_PARAMETERS = TD3_PENDULUM_PARAMETERS;

In [6]:
// --- Network structures -------------------------------------------------------------
// Actor: maps an observation to an action. Critic: maps (observation, action) to one scalar.
// (Dimension arguments are presumably input dim followed by output dim.)
using ACTOR_STRUCTURE_SPEC = bpt::nn_models::mlp::StructureSpecification<
    T, TI,
    ENVIRONMENT::OBSERVATION_DIM,
    ENVIRONMENT::ACTION_DIM,
    ACTOR_NUM_LAYERS,
    ACTOR_HIDDEN_DIM,
    ACTOR_ACTIVATION_FUNCTION,
    ACTOR_ACTIVATION_FUNCTION_OUTPUT,
    TD3_PARAMETERS::ACTOR_BATCH_SIZE
>;
using CRITIC_STRUCTURE_SPEC = bpt::nn_models::mlp::StructureSpecification<
    T, TI,
    ENVIRONMENT::OBSERVATION_DIM + ENVIRONMENT::ACTION_DIM,
    1,
    CRITIC_NUM_LAYERS,
    CRITIC_HIDDEN_DIM,
    CRITIC_ACTIVATION_FUNCTION,
    CRITIC_ACTIVATION_FUNCTION_OUTPUT,
    TD3_PARAMETERS::CRITIC_BATCH_SIZE
>;

// --- Optimizer ----------------------------------------------------------------------
// Adam with the library's "Torch" default parameters (presumably mirroring PyTorch's defaults).
using OPTIMIZER_PARAMETERS = bpt::nn::optimizers::adam::DefaultParametersTorch<T>;
using OPTIMIZER = bpt::nn::optimizers::Adam<OPTIMIZER_PARAMETERS>;

// --- Concrete networks --------------------------------------------------------------
// Trained networks carry Adam state; target networks use the inference-only specification.
using ACTOR_NETWORK_SPEC = bpt::nn_models::mlp::AdamSpecification<ACTOR_STRUCTURE_SPEC>;
using ACTOR_NETWORK_TYPE = bpt::nn_models::mlp::NeuralNetworkAdam<ACTOR_NETWORK_SPEC>;

using ACTOR_TARGET_NETWORK_SPEC = bpt::nn_models::mlp::InferenceSpecification<ACTOR_STRUCTURE_SPEC>;
using ACTOR_TARGET_NETWORK_TYPE = bpt::nn_models::mlp::NeuralNetwork<ACTOR_TARGET_NETWORK_SPEC>;

using CRITIC_NETWORK_SPEC = bpt::nn_models::mlp::AdamSpecification<CRITIC_STRUCTURE_SPEC>;
using CRITIC_NETWORK_TYPE = bpt::nn_models::mlp::NeuralNetworkAdam<CRITIC_NETWORK_SPEC>;

using CRITIC_TARGET_NETWORK_SPEC = bpt::nn_models::mlp::InferenceSpecification<CRITIC_STRUCTURE_SPEC>;
using CRITIC_TARGET_NETWORK_TYPE = bpt::nn_models::mlp::NeuralNetwork<CRITIC_TARGET_NETWORK_SPEC>;

// --- TD3 actor-critic ---------------------------------------------------------------
using TD3_SPEC = bpt::rl::algorithms::td3::Specification<
    T, TI,
    ENVIRONMENT,
    ACTOR_NETWORK_TYPE,
    ACTOR_TARGET_NETWORK_TYPE,
    CRITIC_NETWORK_TYPE,
    CRITIC_TARGET_NETWORK_TYPE,
    TD3_PARAMETERS
>;
using ACTOR_CRITIC_TYPE = bpt::rl::algorithms::td3::ActorCritic<TD3_SPEC>;

// --- Experience collection ----------------------------------------------------------
using OFF_POLICY_RUNNER_SPEC = bpt::rl::components::off_policy_runner::Specification<
    T,
    TI,
    ENVIRONMENT,
    1, // presumably the number of parallel environments (N_ENVIRONMENTS)
    REPLAY_BUFFER_CAP,
    EPISODE_STEP_LIMIT,
    bpt::rl::components::off_policy_runner::DefaultParameters<T>
>;
using OFF_POLICY_RUNNER_TYPE = bpt::rl::components::OffPolicyRunner<OFF_POLICY_RUNNER_SPEC>;

In [7]:
// The training loop passes actor_buffers (sized by ACTOR_BATCH_SIZE) to train_critic and
// critic_buffers (sized by CRITIC_BATCH_SIZE) to train_actor, which presumably only works
// when both batch sizes agree.
static_assert(
    ACTOR_CRITIC_TYPE::SPEC::PARAMETERS::ACTOR_BATCH_SIZE
        == ACTOR_CRITIC_TYPE::SPEC::PARAMETERS::CRITIC_BATCH_SIZE,
    "TD3 actor and critic batch sizes must be equal"
);

In [8]:
// Global objects shared by all following cells.
DEVICE device;
OPTIMIZER optimizer;
// Random engine with the fixed seed 1, so runs are repeatable.
auto rng = bpt::random::default_engine(DEVICE::SPEC::RANDOM{}, 1);
// Passed to bpt::evaluate; false presumably disables any visualization.
bool ui{false};

In [9]:
// TD3 actor-critic: holds the actor, critic_1, critic_2 and (per TD3_SPEC) their target networks.
ACTOR_CRITIC_TYPE actor_critic;
// Storage must be allocated before init(), which presumably initializes the weights
// (and optimizer state) using the given optimizer and rng.
bpt::malloc(device, actor_critic);
bpt::init(device, actor_critic, optimizer, rng);

In [10]:
// Off-policy runner that steps the environments and collects experience for training,
// plus one environment instance per parallel slot.
OFF_POLICY_RUNNER_TYPE off_policy_runner;
ENVIRONMENT envs[OFF_POLICY_RUNNER_TYPE::N_ENVIRONMENTS];
bpt::malloc(device, off_policy_runner);
// init() needs the runner's storage to be allocated first.
bpt::init(device, off_policy_runner, envs);

In [11]:
OFF_POLICY_RUNNER_TYPE::Batch<TD3_PARAMETERS::CRITIC_BATCH_SIZE> critic_batch;
bpt::rl::algorithms::td3::CriticTrainingBuffers<ACTOR_CRITIC_TYPE::SPEC> critic_training_buffers;
CRITIC_NETWORK_TYPE::BuffersForwardBackward<ACTOR_CRITIC_TYPE::SPEC::PARAMETERS::CRITIC_BATCH_SIZE> critic_buffers[2];
bpt::malloc(device, critic_batch);
bpt::malloc(device, critic_training_buffers);
bpt::malloc(device, critic_buffers[0]);
bpt::malloc(device, critic_buffers[1]);

In [12]:
OFF_POLICY_RUNNER_TYPE::Batch<TD3_PARAMETERS::ACTOR_BATCH_SIZE> actor_batch;
bpt::rl::algorithms::td3::ActorTrainingBuffers<ACTOR_CRITIC_TYPE::SPEC> actor_training_buffers;
ACTOR_NETWORK_TYPE::Buffers<ACTOR_CRITIC_TYPE::SPEC::PARAMETERS::ACTOR_BATCH_SIZE> actor_buffers[2];
ACTOR_NETWORK_TYPE::Buffers<OFF_POLICY_RUNNER_SPEC::N_ENVIRONMENTS> actor_buffers_eval;
bpt::malloc(device, actor_batch);
bpt::malloc(device, actor_training_buffers);
bpt::malloc(device, actor_buffers_eval);
bpt::malloc(device, actor_buffers[0]);
bpt::malloc(device, actor_buffers[1]);

In [13]:
bpt::MatrixDynamic<bpt::matrix::Specification<T, TI, 1, ENVIRONMENT::OBSERVATION_DIM>> observations_mean;
bpt::MatrixDynamic<bpt::matrix::Specification<T, TI, 1, ENVIRONMENT::OBSERVATION_DIM>> observations_std;
bpt::malloc(device, observations_mean);
bpt::malloc(device, observations_std);
bpt::set_all(device, observations_mean, 0);
bpt::set_all(device, observations_std, 1);

In [None]:
// TD3 training loop.
// Each iteration steps the environment(s) once with the current actor. Once more than
// N_WARMUP_STEPS steps have been taken:
//   - both critics are updated every iteration;
//   - the actor and all target networks are updated every second iteration
//     (TD3's delayed policy update).
// Unless BACKPROP_TOOLS_RL_ENVIRONMENTS_PENDULUM_DISABLE_EVALUATION is defined, the actor is
// evaluated every 1000 steps.
//
// steady_clock is used for timing because it is monotonic. high_resolution_clock is often an
// alias of the adjustable system_clock and is not meant for measuring durations.
auto start_time = std::chrono::steady_clock::now();

for(TI step_i = 0; step_i < STEP_LIMIT; step_i += OFF_POLICY_RUNNER_SPEC::N_ENVIRONMENTS){
    bpt::set_step(device, device.logger, step_i);
    // Collect experience with the current actor (actor_buffers_eval holds one row per environment).
    bpt::step(device, off_policy_runner, actor_critic.actor, actor_buffers_eval, rng);

    if(step_i > N_WARMUP_STEPS){
        if(step_i % 10 == 0){
            auto current_time = std::chrono::steady_clock::now();
            std::chrono::duration<double> elapsed_seconds = current_time - start_time;
            // std::endl (flush) is intentional so progress shows up immediately in the notebook.
            std::cout << "step_i: " << step_i << " " << elapsed_seconds.count() << "s" << std::endl;
        }

        // Critic updates: each critic gets freshly sampled target-action noise and its own batch.
        for(TI critic_i = 0; critic_i < 2; critic_i++){
            bpt::target_action_noise(device, actor_critic, critic_training_buffers.target_next_action_noise, rng);
            bpt::gather_batch(device, off_policy_runner, critic_batch, rng);
            bpt::train_critic(device, actor_critic, critic_i == 0 ? actor_critic.critic_1 : actor_critic.critic_2, critic_batch, optimizer, actor_buffers[critic_i], critic_buffers[critic_i], critic_training_buffers);
        }

        // Delayed policy update: train the actor, then update the target networks.
        if(step_i % 2 == 0){
            bpt::gather_batch(device, off_policy_runner, actor_batch, rng);
            bpt::train_actor(device, actor_critic, actor_batch, optimizer, actor_buffers[0], critic_buffers[0], actor_training_buffers);

            bpt::update_critic_targets(device, actor_critic);
            bpt::update_actor_target(device, actor_critic);
        }
    }
#ifndef BACKPROP_TOOLS_RL_ENVIRONMENTS_PENDULUM_DISABLE_EVALUATION
    if(step_i % 1000 == 0){
        // Evaluate the actor (presumably 10 episodes of at most EPISODE_STEP_LIMIT steps) using the
        // identity observation normalization set up earlier.
        auto result = bpt::evaluate(device, envs[0], ui, actor_critic.actor, bpt::rl::utils::evaluation::Specification<10, EPISODE_STEP_LIMIT>(), observations_mean, observations_std, rng);
        std::cout << "Mean return: " << result.mean << std::endl;
        bpt::add_scalar(device, device.logger, "mean_return", result.mean);
        // NOTE: an earlier test version of this loop asserted a mean return > -1000 from step 6000
        // and > -400 from step 14000 (the latter is beyond this notebook's STEP_LIMIT).
    }
#endif
}


Mean return: -1479.93
step_i: 110 1.68733s
step_i: 120 3.30863s
step_i: 130 4.81232s
step_i: 140 6.30074s
step_i: 150 7.77397s
step_i: 160 9.24774s
step_i: 170 10.7217s
step_i: 180 12.2024s
step_i: 190 13.6768s
step_i: 200 15.1508s
step_i: 210 16.6292s
step_i: 220 18.1697s
step_i: 230 19.7083s
step_i: 240 21.2014s
step_i: 250 22.6951s
step_i: 260 24.1898s
step_i: 270 25.6633s
step_i: 280 27.1531s
step_i: 290 28.6264s
step_i: 300 30.1016s
step_i: 310 31.5771s
step_i: 320 33.0532s
step_i: 330 34.5327s
step_i: 340 36.0601s
step_i: 350 37.5352s
step_i: 360 39.01s
step_i: 370 40.7928s
step_i: 380 42.4236s
step_i: 390 44.044s
step_i: 400 45.528s
step_i: 410 47.0015s
step_i: 420 48.4798s
step_i: 430 49.9752s
step_i: 440 51.4512s
step_i: 450 52.9286s
step_i: 460 54.4025s
step_i: 470 55.946s
step_i: 480 57.4692s
step_i: 490 59.0917s
step_i: 500 60.5789s
step_i: 510 62.2005s
step_i: 520 64.6572s
step_i: 530 67.4605s
step_i: 540 70.2902s
step_i: 550 73.1081s
step_i: 560 75.6886s
step_i: 570 78.21