This package implements the deep correction method [1] for solving reinforcement learning problems. The user should define the problem according to the POMDPs.jl interface.
[1] M. Bouton, K. Julian, A. Nakhaei, K. Fujimura, and M. J. Kochenderfer, “Utility decomposition with deep corrections for scalable planning under uncertainty,” in International Conference on Autonomous Agents and Multiagent Systems (AAMAS), 2018.
Installation:

```julia
using Pkg
Pkg.add("POMDPs")
POMDPs.add_registry() # to get DeepQLearning and RLInterface
Pkg.add(PackageSpec(url="https://github.com/sisl/DeepCorrections.jl"))
```

Usage:

```julia
using POMDPs
using DeepCorrections
using Flux # for model definition
using DeepQLearning # for underlying DQN solver
using POMDPModels # for gridworld
mdp = SimpleGridWorld()
function my_low_fidelity_values(problem::SimpleGridWorld, s)
    return ones(n_actions(problem)) # dummy example, should return an action value vector
end
model = Chain(Dense(2, 32, relu), Dense(32, n_actions(mdp))) # input is 2 dimensional, x,y positions in grid world
dqn_solver = DeepQLearningSolver(qnetwork = model, verbose=true) # see DQN docs for all the parameters
solver = DeepCorrectionSolver(dqn = dqn_solver,
                              lowfi_values = my_low_fidelity_values)
policy = solve(solver, mdp)
```

The `DeepCorrectionSolver` type relies on the `DeepQLearningSolver` type defined in DeepQLearning.jl. The deep correction solver supports all the options available for the `DeepQLearningSolver`.
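For instance, common DQN hyperparameters can be set directly on the wrapped solver. A minimal sketch; `max_steps` and `learning_rate` are `DeepQLearningSolver` options (see the DeepQLearning.jl documentation for the full list):

```julia
dqn_solver = DeepQLearningSolver(qnetwork = model,
                                 max_steps = 10_000,   # number of training steps
                                 learning_rate = 1e-3, # optimizer learning rate
                                 verbose = true)
solver = DeepCorrectionSolver(dqn = dqn_solver,
                              lowfi_values = my_low_fidelity_values)
```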
`solve` returns a `DeepCorrectionPolicy` object. It can be used like any policy in the POMDPs.jl interface.
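Once trained, the policy can be queried with the standard POMDPs.jl functions. A minimal sketch, assuming the `policy` trained above (`GWPos` is the grid world state type exported by POMDPModels):

```julia
s = GWPos(1, 1)       # a grid world state: (x, y) position
a = action(policy, s) # greedy action according to the corrected Q-values
```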
Low fidelity value estimation:
To provide the low fidelity value function to the solver, the user can use the `lowfi_values` option when initializing the solver. It can be a function or a policy. If it is a function `f`, `f(mdp, s)` will be called to estimate the value. If it is a policy, `actionvalues(policy, s)` will be called. See the documentation in POMDPPolicies.jl for more details on `actionvalues`.
The output should be a vector of size `n_actions(mdp)`. The actions are assumed to be ordered according to the function `action_index` implemented by the problem writer.
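As an illustration of this ordering convention, here is a sketch of a `lowfi_values` function that fills the value vector indexed by `action_index`, using the interface functions shown above; `goal_distance_value` is a hypothetical heuristic, not part of the package:

```julia
function ordered_low_fidelity_values(problem::SimpleGridWorld, s)
    q = zeros(n_actions(problem))
    for a in actions(problem)
        # place each estimate at the index the solver expects for that action
        q[action_index(problem, a)] = goal_distance_value(problem, s, a)
    end
    return q
end
```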
Correction method:
Two default correction methods are available:
- additive correction: `Q_lo(s, a) + delta(s, a)`, where `Q_lo` is the result of `lowfi_values` and `delta` is the correction network.
- multiplicative correction: `Q_lo(s, a) * delta(s, a)`
An additional constant weight can be applied in the correction method using the `correction_weight` option of the solver. The user can write their own correction method via the `correction_method` option. It can be a function or an object. If it is a function `f`, `f(mdp, q_lo, q_corr, correction_weight)` will be called to estimate the value. If it is an object `o`, `correction(o, mdp, q_lo, q_corr, correction_weight)` will be called.
The underlying implementation relies on TensorFlow and static graphs, so the correction method implemented must support tensor inputs and be TensorFlow friendly. The signature should look like this:
```julia
multiplicative_correction(problem::Union{POMDP, MDP}, q_lo::Q, q_corr::Q, weight::Float64) where Q <: Union{Array{Float64}, Tensor}
```
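For example, a custom weighted additive correction could look like the following sketch; it uses broadcasting so that it stays valid for plain array inputs, and it is passed to the solver through the `correction_method` option described above:

```julia
# a minimal sketch of a custom correction method (not part of the package)
function my_weighted_correction(problem::Union{POMDP, MDP}, q_lo, q_corr, weight::Float64)
    return q_lo .+ weight .* q_corr # broadcast keeps it valid for vector inputs
end

solver = DeepCorrectionSolver(dqn = dqn_solver,
                              lowfi_values = my_low_fidelity_values,
                              correction_method = my_weighted_correction)
```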