From 92785b63fd33a328c6e9bd6b22842694a02b7ab3 Mon Sep 17 00:00:00 2001 From: Wouter Caarls Date: Wed, 14 Nov 2018 19:11:46 -0200 Subject: [PATCH] SolverPolicy --- addons/lqr/include/grl/solvers/ilqg.h | 7 ++- addons/lqr/include/grl/solvers/lqr.h | 7 ++- addons/lqr/src/ilqg.cpp | 5 ++ addons/lqr/src/lqr.cpp | 5 ++ base/build.cmake | 1 + base/include/grl/policies/solver.h | 62 +++++++++++++++++++++++ base/include/grl/solver.h | 9 ++++ base/src/policies/solver.cpp | 72 +++++++++++++++++++++++++++ 8 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 base/include/grl/policies/solver.h create mode 100644 base/src/policies/solver.cpp diff --git a/addons/lqr/include/grl/solvers/ilqg.h b/addons/lqr/include/grl/solvers/ilqg.h index d7e973cb..2460ec4a 100644 --- a/addons/lqr/include/grl/solvers/ilqg.h +++ b/addons/lqr/include/grl/solvers/ilqg.h @@ -47,10 +47,10 @@ namespace grl { /// Iterative Linear Quadratic Gaussian trajectory optimizer -class ILQGSolver : public Solver +class ILQGSolver : public PolicySolver { public: - TYPEINFO("solver/ilqg", "Iterative Linear Quadratic Gaussian trajectory optimizer"); + TYPEINFO("solver/policy/ilqg", "Iterative Linear Quadratic Gaussian trajectory optimizer"); protected: typedef std::vector Matrix3D; @@ -81,6 +81,9 @@ class ILQGSolver : public Solver virtual bool solve() { return false; } virtual bool solve(const Vector &x0); virtual bool resolve(double t, const Vector &xt); + + // From PolicySolver + virtual Policy *policy(); protected: bool forwardPass(const ColumnVector &x0, const Matrix &x, const Matrix &u, const Matrix3D &L, Matrix *xnew, Matrix *unew, RowVector *cnew) const; diff --git a/addons/lqr/include/grl/solvers/lqr.h b/addons/lqr/include/grl/solvers/lqr.h index b05b0908..255d4be6 100644 --- a/addons/lqr/include/grl/solvers/lqr.h +++ b/addons/lqr/include/grl/solvers/lqr.h @@ -38,10 +38,10 @@ namespace grl { /// Linear Quadratic Regulator solver -class LQRSolver : public Solver +class LQRSolver : public PolicySolver { public: - TYPEINFO("solver/lqr", "Linear Quadratic Regulator solver") + TYPEINFO("solver/policy/lqr", "Linear Quadratic Regulator solver") protected: ObservationModel *model_; @@ -60,6 +60,9 @@ class LQRSolver : public Solver virtual LQRSolver *clone() const; virtual bool solve(); + // From PolicySolver + virtual Policy *policy(); + protected: virtual int solveDARE(const Matrix &A, const Matrix &B, const Matrix &Q, const Matrix &R, Matrix *X) const; }; diff --git a/addons/lqr/src/ilqg.cpp b/addons/lqr/src/ilqg.cpp index e12d574d..3898eea7 100644 --- a/addons/lqr/src/ilqg.cpp +++ b/addons/lqr/src/ilqg.cpp @@ -334,6 +334,11 @@ bool ILQGSolver::resolve(double t, const Vector &xt) return true; } +Policy *ILQGSolver::policy() +{ + return policy_; +} + bool ILQGSolver::forwardPass(const ColumnVector &x0, const Matrix &x, const Matrix &u, const Matrix3D &L, Matrix *xnew, Matrix *unew, RowVector *cnew) const { size_t n = x0.rows(), // Number of state dimensions diff --git a/addons/lqr/src/lqr.cpp b/addons/lqr/src/lqr.cpp index b1163ba9..3153f288 100644 --- a/addons/lqr/src/lqr.cpp +++ b/addons/lqr/src/lqr.cpp @@ -147,6 +147,11 @@ bool LQRSolver::solve() return true; } +Policy *LQRSolver::policy() +{ + return policy_; +} + #define SSIZE 4096 extern "C" void sb02od_(char *DICO, char *JOBB, char *FACT, char *UPLO, char *JOBL, char *SORT, int *N, int *M, int *P, diff --git a/base/build.cmake b/base/build.cmake index c041f227..2ce4cd2a 100644 --- a/base/build.cmake +++ b/base/build.cmake @@ -77,6 +77,7 @@ add_library(${TARGET} SHARED ${SRC}/policies/noise.cpp ${SRC}/policies/multi.cpp ${SRC}/policies/multi_discrete.cpp + ${SRC}/policies/solver.cpp # ${SRC}/predictors/naf.cpp ${SRC}/predictors/model.cpp ${SRC}/predictors/sarsa.cpp diff --git a/base/include/grl/policies/solver.h b/base/include/grl/policies/solver.h new file mode 100644 index 00000000..d0a30056 --- /dev/null +++ b/base/include/grl/policies/solver.h @@ -0,0 +1,62 @@ +/** \file solver.h + * \brief Solver policy header file. + * + * \author Wouter Caarls + * \date 2018-11-14 + * + * \copyright \verbatim + * Copyright (c) 2018, Wouter Caarls + * All rights reserved. + * + * This file is part of GRL, the Generic Reinforcement Learning library. + * + * GRL is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * \endverbatim + */ + +#ifndef GRL_SOLVER_POLICY_H_ +#define GRL_SOLVER_POLICY_H_ + +#include +#include + +namespace grl +{ + +/// Policy that uses a solver to calculate the action +class SolverPolicy : public Policy +{ + public: + TYPEINFO("mapping/policy/solver", "Policy that uses a solver to calculate the action") + + protected: + PolicySolver *solver_; + int interval_, episodes_; + + public: + SolverPolicy() : solver_(NULL), interval_(1), episodes_(0) { } + + // From Configurable + virtual void request(ConfigurationRequest *config); + virtual void configure(Configuration &config); + virtual void reconfigure(const Configuration &config); + + // From Policy + virtual void act(double time, const Observation &in, Action *out); + virtual void act(const Observation &in, Action *out) const; +}; + +} + +#endif /* GRL_SOLVER_POLICY_H_ */ diff --git a/base/include/grl/solver.h b/base/include/grl/solver.h index 32f09a26..bf8c5f9f 100644 --- a/base/include/grl/solver.h +++ b/base/include/grl/solver.h @@ -29,6 +29,7 @@ #define GRL_SOLVER_H_ #include +#include namespace grl { @@ -49,6 +50,14 @@ class Solver : public Configurable virtual bool resolve(double t, const Vector &xt) { return false; } }; +/// Solver that returns a policy +class PolicySolver : public Solver +{ + public: + /// Returns policy that is adjusted by this solver. + virtual Policy *policy() = 0; +}; + } #endif /* GRL_SOLVER_H_ */ diff --git a/base/src/policies/solver.cpp b/base/src/policies/solver.cpp new file mode 100644 index 00000000..c0ca9d53 --- /dev/null +++ b/base/src/policies/solver.cpp @@ -0,0 +1,72 @@ +/** \file solver.cpp + * \brief Solver policy source file. + * + * \author Wouter Caarls + * \date 2018-11-14 + * + * \copyright \verbatim + * Copyright (c) 2018, Wouter Caarls + * All rights reserved. + * + * This file is part of GRL, the Generic Reinforcement Learning library. + * + * GRL is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * \endverbatim + */ + +#include + +using namespace grl; + +REGISTER_CONFIGURABLE(SolverPolicy) + +void SolverPolicy::request(ConfigurationRequest *config) +{ + config->push_back(CRP("interval", "Episodes between successive solutions (0=asynchronous)", interval_, CRP::Configuration, 1, INT_MAX)); + config->push_back(CRP("solver", "solver/policy", "Solver that calculates the policy", solver_)); +} + +void SolverPolicy::configure(Configuration &config) +{ + solver_ = (PolicySolver*)config["solver"].ptr(); + interval_ = config["interval"]; + + episodes_ = 0; +} + +void SolverPolicy::reconfigure(const Configuration &config) +{ + if (config.has("action") && config["action"].str() == "reset") + episodes_ = 0; +} + +void SolverPolicy::act(double time, const Observation &in, Action *out) +{ + if (time == 0.) + { + episodes_++; + if (interval_ && (episodes_%interval_)==0) + solver_->solve(); + solver_->solve(in); + } + else + solver_->resolve(time, in); + + return solver_->policy()->act(in, out); +} + +void SolverPolicy::act(const Observation &in, Action *out) const +{ + return solver_->policy()->act(in, out); +}