Permalink
Browse files

Add horribly broken SARSA code

  • Loading branch information...
1 parent 5e30a6b commit e622ffded2feb0b6050f38d1c0b9644a5d1222d7 @cmansley cmansley committed Nov 15, 2010
Showing with 151 additions and 2 deletions.
  1. +5 −2 Makefile
  2. +105 −0 sarsa.cc
  3. +41 −0 sarsa.hh
View
@@ -2,8 +2,8 @@ CFLAGS=-Wall -g #-pg
INCLUDES=-I/home/cmansley/boost_1_41_0 #-I /koko/rl3/cmansley/boost_1_41_0 -I/koko/rl3/cmansley/local/include
LIBS= #-L/koko/rl3/cmansley/local/lib
-plan: main.o ccl.o ss.o ip.o uct.o lander.o hoo.o node.o gaussian.o hoot.o mcplanner.o chopper.o di.o lqr.o ddi.o bicycle.o sars.o
- g++ $(CFLAGS) -o plan main.o ccl.o ss.o ip.o uct.o lander.o hoo.o node.o hoot.o mcplanner.o chopper.o di.o lqr.o ddi.o bicycle.o sars.o $(LIBS) -lgsl -lgslcblas -lm -lgflags
+plan: main.o ccl.o ss.o ip.o uct.o lander.o hoo.o node.o gaussian.o hoot.o mcplanner.o chopper.o di.o lqr.o ddi.o bicycle.o sars.o sarsa.o
+ g++ $(CFLAGS) -o plan main.o ccl.o ss.o ip.o uct.o lander.o hoo.o node.o hoot.o mcplanner.o chopper.o di.o lqr.o ddi.o bicycle.o sars.o sarsa.o $(LIBS) -lgsl -lgslcblas -lm -lgflags
temp: temp.o gaussian.o hoo.o node.o chopper.o ip.o ddi.o
g++ $(CFLAGS) -o temp temp.o gaussian.o hoo.o node.o ip.o chopper.o ddi.o -lgsl -lgslcblas -lm
@@ -62,5 +62,8 @@ bicycle.o: bicycle.cc bicycle.hh
sars.o: sars.cc sars.hh
g++ -c $(CFLAGS) sars.cc $(INCLUDES)
+sarsa.o: sarsa.cc sarsa.hh
+ g++ -c $(CFLAGS) sarsa.cc $(INCLUDES)
+
clean:
rm -f *~ *.o plan temp
View
105 sarsa.cc
@@ -0,0 +1,105 @@
+/*
+ * Code by Chris Mansley
+ */
+#include <iostream>
+#include <iterator>
+#include <cmath>
+
+#include "sarsa.hh"
+
+
+/*
+ * Construct the planner: forwards d, c and epsilon to Planner and sets the
+ * learning rate alpha to 0.1.
+ */
+SARSA::SARSA(Domain *d, Chopper *c, double epsilon)
+  : Planner(d, c, epsilon), alpha(0.1)
+{
+}
+
+/*
+ * Cache the domain's reward bounds and discount factor, then derive vmax
+ * from them.
+ */
+void SARSA::initialize()
+{
+  /* Reward bounds as reported by the domain */
+  rmax = domain->getRmax();
+  rmin = domain->getRmin();
+
+  /* Discount factor, followed by vmax = rmax / (1 - gamma) */
+  gamma = domain->getDiscountFactor();
+  vmax = rmax / (1 - gamma);
+}
+
+/*
+ * Apply one SARSA backup for the transition stored in sars.
+ *
+ * Q is keyed by the discretized state followed by the depth and the
+ * discretized action. Entries missing from Q are read as 0. The next action
+ * a' is whatever selectAction() returns for s_prime, and the update is
+ *   Q(s,a) <- Q(s,a) + alpha * (reward + gamma * Q(s',a') - Q(s,a)).
+ *
+ * depth  - depth component of the Q-table keys
+ * sars   - transition supplying s, a, reward and s_prime
+ * qvalue - not used by this update
+ */
+void SARSA::updateValue(int depth, SARS *sars, double qvalue)
+{
+  typedef boost::unordered_map<std::vector<int>, double>::const_iterator QIter;
+
+  /* Key of the pair being updated: (state, depth, action) */
+  std::vector<int> key = chopper->discretizeState(sars->s);
+  key.push_back(depth);
+
+  /* Key of the successor pair. It has to start from s_prime; starting from s
+     would make the bootstrap term read Q(s, a') instead of Q(s', a'). */
+  /* NOTE(review): the next action is selected with depth 0 while the key is
+     stored under depth -- confirm which depth the successor belongs to. */
+  std::vector<int> next_key = chopper->discretizeState(sars->s_prime);
+  next_key.push_back(depth);
+  next_key.push_back(chopper->discretizeAction(selectAction(sars->s_prime, 0, false)));
+
+  /* Q(s', a'), uninformative 0 when the pair has never been written */
+  const QIter next_it = Q.find(next_key);
+  const double q_prime = (next_it != Q.end()) ? next_it->second : 0.0;
+
+  /* Q(s, a), uninformative 0 when the pair has never been written */
+  key.push_back(chopper->discretizeAction(sars->a));
+  const QIter cur_it = Q.find(key);
+  const double q = (cur_it != Q.end()) ? cur_it->second : 0.0;
+
+  /* SARSA rule */
+  Q[key] = q + alpha * (sars->reward + gamma * q_prime - q);
+}
+
+/*
+ * Return the action whose stored Q-value at (s, depth) is largest.
+ *
+ * Every discrete action index in [0, getNumDiscreteActions()) is scored with
+ * its Q entry, or 0 when no entry exists. When several actions share the
+ * largest value the lowest index wins; ties are not randomized. The winning
+ * index is mapped back through continuousAction().
+ *
+ * s      - state to act in
+ * depth  - depth component of the Q-table key
+ * greedy - currently ignored; the selection is always the arg-max
+ */
+Action SARSA::selectAction(State s, int depth, bool greedy)
+{
+  typedef boost::unordered_map<std::vector<int>, double>::const_iterator QIter;
+
+  const int k = chopper->getNumDiscreteActions();
+
+  /* Key layout: discretized state, depth, action */
+  std::vector<int> sad = chopper->discretizeState(s);
+  sad.push_back(depth);
+  sad.push_back(0); // action slot, overwritten for each candidate
+
+  /* Running arg-max; index 0 is the result when there are no actions */
+  int discreteAction = 0;
+  double best_q = 0.0;
+  for(int action = 0; action < k; action++) {
+    sad.back() = action;
+
+    /* Unseen pairs count as 0 */
+    const QIter it = Q.find(sad);
+    const double q = (it != Q.end()) ? it->second : 0.0;
+
+    /* Strict comparison keeps the first of several equal maxima */
+    if(action == 0 || best_q < q) {
+      discreteAction = action;
+      best_q = q;
+    }
+  }
+
+  return chopper->continuousAction(discreteAction);
+}
View
@@ -0,0 +1,41 @@
+/*
+ * Code by Chris Mansley
+ */
+
+#ifndef SARSA_HH
+#define SARSA_HH
+
+/* Definition dependencies */
+#include <boost/unordered_map.hpp>
+
+#include "planner.hh"
+#include "sars.hh"
+
+/**
+ * Planner that keeps a table of Q-values and updates it with the SARSA rule.
+ */
+class SARSA : public Planner
+{
+public:
+  /** Constructor: forwards d, c and epsilon to Planner */
+  SARSA(Domain *d, Chopper *c, double epsilon);
+
+  /** Destructor */
+  ~SARSA( ) { }
+
+  /** Initialize the planner */
+  void initialize();
+
+  /** Update the Q table from one (s, a, r, s') transition */
+  void updateValue(int depth, SARS *sars, double qvalue);
+
+  /**
+   * Pick the action with the largest stored Q-value for s at depth.
+   * Declared here because sarsa.cc defines SARSA::selectAction, which is
+   * ill-formed without a declaration inside this class.
+   */
+  Action selectAction(State s, int depth, bool greedy);
+
+  /**
+   * Query action.
+   * NOTE(review): sarsa.cc provides no definition for this member yet.
+   */
+  Action plan(State s);
+
+protected:
+  /** Learning rate */
+  double alpha;
+
+  /** Q table; keys are vectors of ints built by the update and selection code */
+  boost::unordered_map<std::vector<int>, double> Q;
+
+};
+
+#endif // SARSA_HH

0 comments on commit e622ffd

Please sign in to comment.