Skip to content

Commit

Permalink
SolverPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
wcaarls committed Nov 14, 2018
1 parent c0865b2 commit 92785b6
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 4 deletions.
7 changes: 5 additions & 2 deletions addons/lqr/include/grl/solvers/ilqg.h
Expand Up @@ -47,10 +47,10 @@
namespace grl {

/// Iterative Linear Quadratic Gaussian trajectory optimizer
class ILQGSolver : public Solver
class ILQGSolver : public PolicySolver
{
public:
TYPEINFO("solver/ilqg", "Iterative Linear Quadratic Gaussian trajectory optimizer");
TYPEINFO("solver/policy/ilqg", "Iterative Linear Quadratic Gaussian trajectory optimizer");

protected:
typedef std::vector<Matrix> Matrix3D;
Expand Down Expand Up @@ -81,6 +81,9 @@ class ILQGSolver : public Solver
virtual bool solve() { return false; }
virtual bool solve(const Vector &x0);
virtual bool resolve(double t, const Vector &xt);

// From PolicySolver
virtual Policy *policy();

protected:
bool forwardPass(const ColumnVector &x0, const Matrix &x, const Matrix &u, const Matrix3D &L, Matrix *xnew, Matrix *unew, RowVector *cnew) const;
Expand Down
7 changes: 5 additions & 2 deletions addons/lqr/include/grl/solvers/lqr.h
Expand Up @@ -38,10 +38,10 @@ namespace grl
{

/// Linear Quadratic Regulator solver
class LQRSolver : public Solver
class LQRSolver : public PolicySolver
{
public:
TYPEINFO("solver/lqr", "Linear Quadratic Regulator solver")
TYPEINFO("solver/policy/lqr", "Linear Quadratic Regulator solver")

protected:
ObservationModel *model_;
Expand All @@ -60,6 +60,9 @@ class LQRSolver : public Solver
virtual LQRSolver *clone() const;
virtual bool solve();

// From PolicySolver
virtual Policy *policy();

protected:
virtual int solveDARE(const Matrix &A, const Matrix &B, const Matrix &Q, const Matrix &R, Matrix *X) const;
};
Expand Down
5 changes: 5 additions & 0 deletions addons/lqr/src/ilqg.cpp
Expand Up @@ -334,6 +334,11 @@ bool ILQGSolver::resolve(double t, const Vector &xt)
return true;
}

/// Hands out the policy object held by this solver (PolicySolver interface).
Policy *ILQGSolver::policy()
{
  Policy *held = policy_;
  return held;
}

bool ILQGSolver::forwardPass(const ColumnVector &x0, const Matrix &x, const Matrix &u, const Matrix3D &L, Matrix *xnew, Matrix *unew, RowVector *cnew) const
{
size_t n = x0.rows(), // Number of state dimensions
Expand Down
5 changes: 5 additions & 0 deletions addons/lqr/src/lqr.cpp
Expand Up @@ -147,6 +147,11 @@ bool LQRSolver::solve()
return true;
}

/// Hands out the policy object held by this solver (PolicySolver interface).
Policy *LQRSolver::policy()
{
  Policy *held = policy_;
  return held;
}

#define SSIZE 4096
extern "C" void sb02od_(char *DICO, char *JOBB, char *FACT, char *UPLO,
char *JOBL, char *SORT, int *N, int *M, int *P,
Expand Down
1 change: 1 addition & 0 deletions base/build.cmake
Expand Up @@ -77,6 +77,7 @@ add_library(${TARGET} SHARED
${SRC}/policies/noise.cpp
${SRC}/policies/multi.cpp
${SRC}/policies/multi_discrete.cpp
${SRC}/policies/solver.cpp
# ${SRC}/predictors/naf.cpp
${SRC}/predictors/model.cpp
${SRC}/predictors/sarsa.cpp
Expand Down
62 changes: 62 additions & 0 deletions base/include/grl/policies/solver.h
@@ -0,0 +1,62 @@
/** \file solver.h
* \brief Solver policy header file.
*
* \author Wouter Caarls <wouter@caarls.org>
* \date 2018-11-14
*
* \copyright \verbatim
* Copyright (c) 2018, Wouter Caarls
* All rights reserved.
*
* This file is part of GRL, the Generic Reinforcement Learning library.
*
* GRL is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* \endverbatim
*/

#ifndef GRL_SOLVER_POLICY_H_
#define GRL_SOLVER_POLICY_H_

#include <grl/policy.h>
#include <grl/solver.h>

namespace grl
{

/**
 * Policy whose actions are produced by a PolicySolver.
 *
 * The timed act() drives the solver (solve at the start of an episode,
 * resolve on later steps) and then delegates to the policy the solver
 * exposes; the untimed act() only delegates.
 */
class SolverPolicy : public Policy
{
  public:
    TYPEINFO("mapping/policy/solver", "Policy that uses a solver to calculate the action")

  protected:
    PolicySolver *solver_; ///< Solver supplying the policy that is queried for actions.
    int interval_;         ///< Episodes between successive calls to the solver's argument-less solve().
    int episodes_;         ///< Number of episode starts seen since configuration or the last reset.

  public:
    SolverPolicy() : solver_(NULL), interval_(1), episodes_(0) { }

    // From Configurable
    virtual void request(ConfigurationRequest *config);
    virtual void configure(Configuration &config);
    virtual void reconfigure(const Configuration &config);

    // From Policy
    virtual void act(double time, const Observation &in, Action *out);
    virtual void act(const Observation &in, Action *out) const;
};

}

#endif /* GRL_SOLVER_POLICY_H_ */
9 changes: 9 additions & 0 deletions base/include/grl/solver.h
Expand Up @@ -29,6 +29,7 @@
#define GRL_SOLVER_H_

#include <grl/configurable.h>
#include <grl/policy.h>

namespace grl
{
Expand All @@ -49,6 +50,14 @@ class Solver : public Configurable
virtual bool resolve(double t, const Vector &xt) { return false; }
};

/// Solver that returns a policy
///
/// Extends the Solver interface with access to a Policy object, so that
/// callers can ask the solver for the policy to evaluate.
class PolicySolver : public Solver
{
public:
/// Returns policy that is adjusted by this solver.
/// Pure virtual: every concrete policy solver must provide it.
/// NOTE(review): ownership of the returned pointer is not stated here --
/// presumably it stays with the solver; confirm against implementations.
virtual Policy *policy() = 0;
};

}

#endif /* GRL_SOLVER_H_ */
72 changes: 72 additions & 0 deletions base/src/policies/solver.cpp
@@ -0,0 +1,72 @@
/** \file solver.cpp
* \brief Solver policy source file.
*
* \author Wouter Caarls <wouter@caarls.org>
* \date 2018-11-14
*
* \copyright \verbatim
* Copyright (c) 2018, Wouter Caarls
* All rights reserved.
*
* This file is part of GRL, the Generic Reinforcement Learning library.
*
* GRL is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* \endverbatim
*/

#include <grl/policies/solver.h>

using namespace grl;

REGISTER_CONFIGURABLE(SolverPolicy)

/**
 * Registers the configuration parameters of this policy.
 *
 * - "interval": number of episodes between successive calls to the solver's
 *   argument-less solve(). The value 0 is advertised as "asynchronous" and is
 *   handled by act() (which skips the periodic solve when interval_ is 0), so
 *   the permitted range must start at 0 rather than 1; otherwise the
 *   documented setting could never be selected.
 * - "solver": the PolicySolver (type "solver/policy") that calculates the
 *   policy.
 *
 * \param config Request list to which the parameter descriptions are appended.
 */
void SolverPolicy::request(ConfigurationRequest *config)
{
  config->push_back(CRP("interval", "Episodes between successive solutions (0=asynchronous)", interval_, CRP::Configuration, 0, INT_MAX));
  config->push_back(CRP("solver", "solver/policy", "Solver that calculates the policy", solver_));
}

/**
 * Reads the "solver" and "interval" parameters from the configuration and
 * restarts the episode counter.
 *
 * \param config Configuration holding the values for the requested parameters.
 */
void SolverPolicy::configure(Configuration &config)
{
  // Fetch both settings first (same lookup order as before), then commit them.
  PolicySolver *configured_solver = (PolicySolver*)config["solver"].ptr();
  int configured_interval = config["interval"];

  solver_ = configured_solver;
  interval_ = configured_interval;
  episodes_ = 0;
}

/**
 * Handles run-time reconfiguration requests.
 *
 * Only the "action" = "reset" message is acted upon: it sets the episode
 * counter back to zero. Anything else is ignored.
 *
 * \param config Reconfiguration message.
 */
void SolverPolicy::reconfigure(const Configuration &config)
{
  // Nothing to do unless an action was specified.
  if (!config.has("action"))
    return;

  if (config["action"].str() == "reset")
    episodes_ = 0;
}

void SolverPolicy::act(double time, const Observation &in, Action *out)
{
if (time == 0.)
{
episodes_++;
if (interval_ && (episodes_%interval_)==0)
solver_->solve();
solver_->solve(in);
}
else
solver_->resolve(time, in);

return solver_->policy()->act(in, out);
}

/**
 * Untimed action selection: forwards directly to the solver's policy without
 * invoking the solver itself.
 *
 * \param in  Current observation.
 * \param out Receives the action chosen by the solver's policy.
 */
void SolverPolicy::act(const Observation &in, Action *out) const
{
  solver_->policy()->act(in, out);
}

0 comments on commit 92785b6

Please sign in to comment.