-
Notifications
You must be signed in to change notification settings - Fork 0
/
reinforcement_agent.h
51 lines (40 loc) · 1.15 KB
/
reinforcement_agent.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/*
Copyright (C) 2017 Meritxell Jordana
Copyright (C) 2017 Marc Sanchez
*/
#ifndef __REINFORCEMENT_AGENT_H
#define __REINFORCEMENT_AGENT_H
#include <map>
#include "map.h"
#include "strategy.h"
/**
 * Key type that pairs a Map state with a Direction action.
 *
 * Both members are const, so a key cannot be modified once constructed.
 * operator< gives keys an ordering; this header includes <map>, so the key
 * is presumably meant for use in a std::map -- confirm against the
 * implementation file.
 */
struct QValuesKey {
    /// State component of the key, stored by value.
    const Map state;
    /// Action component of the key.
    const Direction action;

    /// Builds a key from a state and an action, both received by value.
    QValuesKey(const Map state, const Direction action);

    /// Ordering between two keys; the comparison criterion is defined out
    /// of line and is not visible from this header.
    bool operator<(const QValuesKey &qvk) const;
};
/**
 * Abstract Strategy subclass declaring the bookkeeping shared by
 * reinforcement-learning agents.
 *
 * The class is abstract because update() is pure virtual: a concrete agent
 * must derive from it and supply update(). Member functions are only
 * declared here; their behaviour lives in the implementation file.
 *
 * NOTE(review): whether deleting a derived agent through a base pointer is
 * safe depends on Strategy declaring a virtual destructor -- confirm in
 * strategy.h.
 */
class ReinforcementAgent : public Strategy {
private:
    // Episode and reward counters. Meanings below are inferred from the
    // names only -- TODO confirm against the implementation file.
    int numTraining;           // presumably the number of training episodes
    int episodes;              // presumably the count of episodes so far
    double accumTrainRewards;  // presumably rewards summed while training
    double accumTestRewards;   // presumably rewards summed while testing

    // Most recent state/action pair. lastState is a raw pointer; ownership
    // cannot be determined from this header.
    Map *lastState;
    Direction lastAction;

    double episodeRewards;     // presumably rewards of the current episode

    /// Learning hook that every concrete agent has to implement. It is
    /// private, so it can be overridden by subclasses but only called from
    /// within this class.
    virtual void update(QValuesKey &key, Map &nextState, double reward) = 0;

    /// Receives a key (by value), the following state and a reward delta.
    void observeTransition(QValuesKey key, Map &nextState, double deltaReward);

    void startEpisode();
    void stopEpisode();

protected:
    // Parameters exposed to subclasses.
    double epsilon;
    double alpha;
    double discount;  // presumably set from the constructor's `gamma` -- confirm

    void doAction(Map &state, Direction action);
    void observationFunction(Map &state);

public:
    /// Constructs the agent from a map pointer, a training count and three
    /// floating-point parameters (epsilon, alpha, gamma).
    ReinforcementAgent(Map *map,
                       int numTraining,
                       double epsilon,
                       double alpha,
                       double gamma);

    /// `final` is only a contextual keyword in C++, so it remains a legal
    /// member-function name here.
    void final();

    void registerInitialState(Map *state);
};
#endif