## Example using TD3 to learn a policy

This notebook illustrates how to use TD3 to learn a policy for a simple `ElectricGridEnv`. The environment is the same as the one used in [RL_Complex_DEMO_external_agent](RL_Complex_DEMO_external_agent.ipynb). It also shows how to use `Wandb.jl` to log the training process; for details on customizing the logging, see the `Wandb.jl` [documentation](https://avik-pal.github.io/Wandb.jl/stable/).

In [2]:
using ElectricGrid
using ReinforcementLearning
using Flux
using Flux.Losses
using StableRNGs
using IntervalSets
using Zygote: ignore
using Logging
using Wandb

# load the TD3 agent code (agent_td3.jl) from the ElectricGrid source directory
td3_src_dir = dirname(pathof(ElectricGrid))
include(joinpath(td3_src_dir, "agent_td3.jl"))


# Weights & Biases (wandb) logging. To train without wandb, skip the logger
# setup below and the `close(logger)` call at the end of the notebook.
# Make sure to have a wandb account and be logged in:
# https://docs.wandb.ai/quickstart
logger = WandbLogger(
    # Provide a project name and an entity name
    project="TD3",
    # optionally provide a team name if it is created
    entity="electricgrid-jl",
    # optionally provide a run name
    name="train with TD3",
    # optionally provide a config
    config=Dict(
        "lr" => 3e-5,
        ),
)

# Make the wandb logger the global logger so that metrics can be logged
# with `@info`, e.g. @info "metrics" actor_loss=loss
global_logger(logger)


wandb: Currently logged in as: vikasc-nitk (electricgrid-jl). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in /data/cvikas/Projects/ElectricGrid.jl/examples/notebooks/wandb/run-20230714_090659-gmesrzpw
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run train with TD3
wandb: ⭐️ View project at https://wandb.ai/electricgrid-jl/TD3
wandb: 🚀 View run at https://wandb.ai/electricgrid-jl/TD3/runs/gmesrzpw

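Setting `global_logger(logger)` routes every `@info` record to wandb. The mechanism is Julia's standard `Logging` interface: a logger is any `AbstractLogger` subtype, and `@info "metrics" actor_loss=loss` hands `"metrics"` as the message and `actor_loss` as a keyword argument to `handle_message`. The sketch below uses a hypothetical in-memory `CollectingLogger` in place of `WandbLogger`, so it runs without a wandb account. Judging from the run summary further down, `WandbLogger` appears to combine message and key into names such as `metrics/actor_loss`, but that mapping is not reproduced here.

```julia
using Logging

# Hypothetical logger that stores each record in memory (stands in for WandbLogger)
struct CollectingLogger <: AbstractLogger
    records::Vector{Pair{String, Dict{Symbol, Any}}}
end
CollectingLogger() = CollectingLogger(Pair{String, Dict{Symbol, Any}}[])

# accept every record at level Info and above
Logging.min_enabled_level(::CollectingLogger) = Logging.Info
Logging.shouldlog(::CollectingLogger, level, _module, group, id) = true

# store the message together with its key-value pairs
function Logging.handle_message(logger::CollectingLogger, level, message, _module, group, id,
                                file, line; kwargs...)
    push!(logger.records, string(message) => Dict{Symbol, Any}(kwargs))
    return nothing
end

demo_logger = CollectingLogger()

# `with_logger` scopes the logger to this block; `global_logger` would install it everywhere
with_logger(demo_logger) do
    @info "metrics" actor_loss = 0.5 critic_loss = 1.0
end

println(demo_logger.records)
```

The variable is called `demo_logger` so that running the sketch inside this notebook does not overwrite the wandb `logger`.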
In [2]:
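# Use the same environment configuration as in RL_Complex_DEMO_external_agent
# Connection matrix (note: it is not passed to ElectricGridEnv below, where `CM = CM` is commented out)
CM = [
    0. 0. 1.
    0. 0. 2.
   -1. -2. 0.
]

# R_load and L_load are computed here but not used in `parameters` below
R_load, L_load, _, _ = ParallelLoadImpedance(50e3, 0.95, 230)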

parameters = Dict{Any, Any}(
    "source" => Any[
        Dict{Any, Any}(
            "pwr" => 200e3,
            "control_type" => "RL",
            "mode" => "my_agent",
            "fltr" => "L",
            #"L1" => 0.0008,
        ),
        Dict{Any, Any}(
            "pwr" => 200e3,
            "fltr" => "LC",
            "control_type" => "classic",
            "mode" => "Droop",
        ),
    ],
    "grid" => Dict{Any, Any}(
        "phase" => 3,
        "ramp_end" => 0.04,
    ),
)


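# Three-phase current reference: zero until t = 0.04 s (the same value as `ramp_end`),
# then a 50 Hz sinusoid with amplitude 10 per phase, shifted by ±120°
function reference(t)
    if t < 0.04
        return [0.0, 0.0, 0.0]
    end

    θ = 2π * 50 * t
    θph = [θ; θ - 120π/180; θ + 120π/180]
    return 10 .* cos.(θph)
end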


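# Augment the observation of "my_agent" with the normalized current reference
featurize_ddpg = function(state, env, name)
    if name == "my_agent"
        norm_ref = env.nc.parameters["source"][1]["i_limit"]
        state = vcat(state, reference(env.t) / norm_ref)
    end
    # return the (possibly augmented) state for every agent, not only "my_agent"
    return state
end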


function reward_function(env, name = nothing)
    if name == "classic"
        return 0
    end

    # normalized inductor currents of the RL-controlled source (phases a, b, c)
    state_to_control = [env.state[findfirst(==("source1_i_L1_$ph"), env.state_ids)] for ph in ("a", "b", "c")]

    # penalize leaving the normalized current limits
    if any(abs.(state_to_control) .> 1)
        return -1
    end

    # reward close tracking of the normalized reference (1 for perfect tracking)
    refs = reference(env.t)
    norm_ref = env.nc.parameters["source"][1]["i_limit"]
    return 1 - 1/3 * sum((abs.(refs / norm_ref - state_to_control) / 2) .^ 0.5)
end


env = ElectricGridEnv(
    #CM =  CM,
    parameters = parameters,
    t_end = 1,
    reward_function = reward_function,
    featurize = featurize_ddpg,
    action_delay = 0,
    verbosity = 0
);



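The `reference` function defines the current setpoint that the agent should track. Written out for the phases $p \in \{a, b, c\}$, with $t$ in seconds, it reads:

$$
i^{\mathrm{ref}}_{p}(t) =
\begin{cases}
0, & t < 0.04,\\
10\cos\!\left(2\pi \cdot 50\, t - \varphi_p\right), & t \ge 0.04,
\end{cases}
\qquad \varphi_a = 0,\quad \varphi_b = \tfrac{2\pi}{3},\quad \varphi_c = -\tfrac{2\pi}{3}.
$$

`featurize_ddpg` appends $i^{\mathrm{ref}}_{p}(t)/i_{\mathrm{limit}}$ to the observation of `my_agent`, so the actor sees the setpoint on the same scale as the normalized currents used in the reward.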

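The reward in `reward_function` depends only on the three normalized phase currents and the reference. The standalone sketch below reproduces that formula with plain arguments instead of the environment; the helper name `tracking_reward` is hypothetical and not part of ElectricGrid.jl. It makes the range of the reward visible: −1 whenever a normalized current exceeds 1 in magnitude, otherwise a value between 0 and 1, with 1 for perfect tracking (assuming the reference amplitude stays below `i_limit`).

```julia
# Hypothetical standalone version of the reward used in reward_function.
# i_meas: normalized phase currents (a, b, c); i_ref: reference; i_limit: current limit
function tracking_reward(i_meas::AbstractVector, i_ref::AbstractVector, i_limit::Real)
    # leaving the normalized limits is penalized with -1
    any(abs.(i_meas) .> 1) && return -1.0
    # square-root error per phase, scaled so that each term lies in [0, 1]
    err = sqrt.(abs.(i_ref ./ i_limit .- i_meas) ./ 2)
    return 1 - sum(err) / 3
end

println(tracking_reward([0.5, -0.25, 0.0], [5.0, -2.5, 0.0], 10.0))     # perfect tracking
println(tracking_reward([1.0, 1.0, 1.0], [-10.0, -10.0, -10.0], 10.0))  # worst case within limits
println(tracking_reward([1.5, 0.0, 0.0], [0.0, 0.0, 0.0], 10.0))        # limit violated
```

The square root makes the reward more sensitive to small tracking errors than a linear or quadratic error term would, since its slope grows without bound as the error approaches zero.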

We have slightly modified the TD3 implementation from `ReinforcementLearning.jl` to make it compatible with the `ElectricGrid.jl` framework.

Details of the TD3 algorithm are given in the paper [Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477), and the original Julia implementation is available [here](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/blob/2e1de3e5b6b8224f50b3d11bba7e1d2d72c6ef7c/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/td3.jl).
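The core of TD3, as described in the paper, is how the critic target is formed. Target policy smoothing adds clipped Gaussian noise to the target actor's action, clipped double-Q learning takes the minimum of the two target critics (the twin critics built with `TD3Critic` below), and the target networks follow the behavior networks by Polyak averaging with factor ρ. The actor and the targets are additionally updated only every `policy_freq` critic updates. The sketch below illustrates the target and the soft update with plain Julia functions in place of Flux networks; it is not the code in `agent_td3.jl`, and the names `td3_target` and `soft_update!` are hypothetical. The keyword arguments are assumed to correspond to `γ`, `target_act_noise`, `target_act_limit` and `act_limit` in the agent configuration.

```julia
using Random

# Sketch of the TD3 critic target (Fujimoto et al., 2018) for a batch of transitions.
# Plain functions stand in for the target actor and the two target critics.
function td3_target(r, s_next, done, target_actor, target_q1, target_q2;
                    γ = 0.99, noise_std = 0.1, noise_clip = 1.0, act_limit = 1.0,
                    rng = Random.default_rng())
    a_next = target_actor(s_next)
    # target policy smoothing: clipped Gaussian noise on the target action
    ε = clamp.(noise_std .* randn(rng, size(a_next)...), -noise_clip, noise_clip)
    a_next = clamp.(a_next .+ ε, -act_limit, act_limit)
    # clipped double-Q learning: use the smaller of the two target estimates
    q_next = min.(target_q1(s_next, a_next), target_q2(s_next, a_next))
    # no bootstrapping from terminal transitions (done = 1)
    return r .+ γ .* (1 .- done) .* q_next
end

# Polyak averaging of target parameters: θ_target ← ρ θ_target + (1 - ρ) θ_behavior
function soft_update!(target::AbstractArray, behavior::AbstractArray, ρ)
    target .= ρ .* target .+ (1 - ρ) .* behavior
    return target
end

# toy example with one-dimensional states and actions and the noise switched off
y = td3_target([1.0, 1.0], [1.0, 2.0], [0.0, 1.0],
               s -> 0.5 .* s, (s, a) -> s .+ a, (s, a) -> s .+ 2 .* a;
               γ = 0.5, noise_std = 0.0)
println(y)
```

Taking the minimum of the two critics counteracts the overestimation bias of a single learned Q-function, which is the main motivation given in the paper.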

In [3]:
rng = StableRNG(0)
init = glorot_uniform(rng)

# specify number of states and actions to be controlled by the agent
ns = length(ElectricGrid.state(env, "my_agent"))
na = length(env.agent_dict["my_agent"]["action_ids"])

CreateActor() = Chain(
    Dense(ns, 32, relu; init = init),
    Dense(32, 32, relu; init = init),
    Dense(32, na, tanh; init = init)
)

CreateCriticModel() = Chain(
    Dense(ns + na, 64, relu; init = init),
    Dense(64, 64, relu; init = init),
    Dense(64, 1; init = init)
)


# create twin critic models
CreateCritic() = TD3Critic(
    CreateCriticModel(),
    CreateCriticModel(),
)

# learning_rate = logger.config["lr"]
learning_rate = 3e-5

TD3_agent = Agent(
    policy = TD3Policy(
        behavior_actor = NeuralNetworkApproximator(
            model = CreateActor(),
            optimizer = ADAM(learning_rate),
        ),
        behavior_critic = NeuralNetworkApproximator(
            model = CreateCritic(),
            optimizer = ADAM(learning_rate),
        ),
        target_actor = NeuralNetworkApproximator(
            model = CreateActor(),
            optimizer = ADAM(learning_rate),
        ),
        target_critic = NeuralNetworkApproximator(
            model = CreateCritic(),
            optimizer = ADAM(learning_rate),
        ),
        γ = 0.99f0,
        ρ = 0.995f0,
        batch_size = 64,
        start_steps = 10,
        # start_steps = -1,
        start_policy = RandomPolicy(-1.0..1.0; rng = rng),
        update_after = 10,
        update_freq = 1,
        policy_freq = 2,
        target_act_limit = 1.0,
        target_act_noise = 0.1,
        act_limit = 1.0,
        act_noise = 0.05,
        rng = rng,
    ),

    trajectory = CircularArraySARTTrajectory(
            capacity = 10_000,
            state = Vector{Float32} => (ns,),
            action = Float32 => (na, ),
    ),
);

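During data collection, TD3 as described in the paper first takes random actions for a few steps (here `start_steps = 10`, drawn from the uniform `RandomPolicy` on [-1, 1]) and afterwards uses the actor's output plus Gaussian exploration noise with standard deviation `act_noise`, clipped to ±`act_limit`. The sketch below assumes `agent_td3.jl` follows this scheme; `behavior_action` is a hypothetical helper, and a plain `tanh` function stands in for the actor network.

```julia
using Random

# Hypothetical helper: action selection during training
function behavior_action(actor, s, step, na; start_steps = 10, act_noise = 0.05,
                         act_limit = 1.0, rng = Random.default_rng())
    if step <= start_steps
        # exploration phase: uniform random actions in [-act_limit, act_limit]
        return act_limit .* (2 .* rand(rng, na) .- 1)
    end
    a = actor(s)
    # Gaussian exploration noise, then clip to the valid action range
    return clamp.(a .+ act_noise .* randn(rng, length(a)), -act_limit, act_limit)
end

demo_rng = Random.MersenneTwister(0)
toy_actor = s -> tanh.(s)   # stands in for the Flux actor with tanh output layer

a_random = behavior_action(toy_actor, [0.3, -2.0], 1, 2; rng = demo_rng)   # during start_steps
a_policy = behavior_action(toy_actor, [0.3, -2.0], 11, 2; rng = demo_rng)  # after start_steps
println(a_random, " ", a_policy)
```

Because the actor's last layer uses `tanh`, its output already lies in [-1, 1]; the final clipping only matters once the exploration noise is added.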
In [4]:
td3_agent = Dict("my_agent" => TD3_agent)

controllers = SetupAgents(env, td3_agent)

Learn(
    controllers,
    env, 
    num_episodes = 100,
);



Progress: 100%|█████████████████████████████████████████| Time: 0:01:19


(Plot: Total reward per episode; series: classic, my_agent)

In [5]:
# close wandb session
close(logger)

wandb: Waiting for W&B process to finish... (success).
wandb: 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)
wandb: Run history:
wandb:       Episode/episode ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:    metrics/actor_loss █████▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▂▁▁▁
wandb:   metrics/critic_loss ▁▁▁▁▁▁▂▁▁▂▁▁▁▁▁▂▁▁▁▂▁▁▂▁▁▁▁▂▂▁▂▁▁▁▄▁▆▁▁█
wandb: total_timesteps/tstep █▆▄███████▇█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:       Episode/episode 1
wandb:    metrics/actor_loss -25.94836
wandb:   metrics/critic_loss 1.19431
wandb: total_timesteps/tstep 14
wandb: 
wandb: 🚀 View run train with TD3 at: https://wandb.ai/electricgrid-jl/TD3/runs/qme67myo
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230713_173906-qme67myo/logs



The training metrics can be tracked in real time on the wandb website. For illustration, the actor and critic losses recorded during training are shown here:
<!-- put images side by side-->
<!-- ![Wandb metrics](figures/TD3_actor_loss.svg) -->
<img src="figures/TD3_actor_loss.svg"
        style="width:700px;"/>
<img src="figures/TD3_critic_loss.svg"
        style="width:700px;"/>
<!-- ![Wandb metrics](figures/TD3_critic_loss.svg) -->

In [None]:
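# collect all states and actions while simulating with the trained agent
hook = DataHook(collect_state_ids = env.state_ids,
                collect_action_ids = env.action_ids)

hook = Simulate(controllers, env, hook = hook)

# plot the collected states, actions and the reward
RenderHookResults(hook = hook,
                  states_to_plot = env.state_ids,
                  actions_to_plot = env.action_ids,
                  plot_reward = true)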