CreateSurprise.java
package rlpark.example.irobot.surprise;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import rlpark.plugin.irobot.data.CreateAction;
import rlpark.plugin.irobot.data.IRobotSongs;
import rlpark.plugin.irobot.robots.CreateRobot;
import rlpark.plugin.irobot.robots.IRobotEnvironment;
import rlpark.plugin.rltoys.algorithms.functions.states.AgentState;
import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda;
import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.observations.ObsFilter;
import rlpark.plugin.rltoys.envio.observations.Observation;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.envio.policy.SingleActionPolicy;
import rlpark.plugin.rltoys.horde.Horde;
import rlpark.plugin.rltoys.horde.Surprise;
import rlpark.plugin.rltoys.horde.demons.Demon;
import rlpark.plugin.rltoys.horde.demons.PredictionDemon;
import rlpark.plugin.rltoys.horde.demons.PredictionOffPolicyDemon;
import rlpark.plugin.rltoys.horde.functions.HordeUpdatable;
import rlpark.plugin.rltoys.horde.functions.RewardFunction;
import rlpark.plugin.rltoys.horde.functions.RewardObservationFunction;
import rlpark.plugin.rltoys.math.vector.RealVector;

import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.synchronization.Clock;
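
/**
 * Demo: runs a Horde of prediction demons on an iRobot Create and tracks a
 * surprise measure over their prediction errors. When the measure spikes above
 * a threshold, the robot plays a song. For each predicted sensor, the horde
 * contains one on-policy (nexting) demon per time scale in Gammas, plus one
 * off-policy GTD(lambda) demon per target policy in TargetPolicies.
 */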
@SuppressWarnings("restriction")
@Monitor
public class CreateSurprise implements Runnable {
  // Tracking speed of the running statistics behind the surprise measure
  static final private int SurpriseTrackingSpeed = 100;
  // Actions available to the behaviour policy
  static final private Action[] Actions = new Action[] { CreateAction.DontMove, CreateAction.SpinLeft,
      CreateAction.SpinRight, CreateAction.Forward };
  // Observation labels to predict; each matching sensor becomes a reward function
  static final private String[] PredictedLabels = new String[] { "WheelDrop", "Bump", "WheelOverCurrent", "ICOmni",
      "DriveDistance", "DriveAngle", "BatteryCurrent", "BatteryCharge", "WallSignal", "CliffSignal",
      "ConnectedHomeBase", "OIMode", "WheelRequested" };
  // Discount factors: one demon per label and time scale, from immediate to long-term predictions
  static final private double[] Gammas = new double[] { .0, 0.9, 0.99 };
  // Target policies for the off-policy demons: keep spinning left, or keep driving forward
  static final private Policy[] TargetPolicies = new Policy[] { new SingleActionPolicy(CreateAction.SpinLeft),
      new SingleActionPolicy(CreateAction.Forward) };
  static final private double Lambda = .7;
  final private IRobotEnvironment robot = new CreateRobot();
  final private Clock clock = new Clock("Surprise");
  final private Horde horde;
  final private Surprise surprise;
  private final AgentState agentState;
  private final Policy robotBehaviour;
  // Previous feature vector and action, kept to form the transition at each step
  private RealVector x_t;
  private Action a_t;

  public CreateSurprise() {
    agentState = new RobotState();
    // Random behaviour policy over Actions (see RobotBehaviour in this package)
    robotBehaviour = new RobotBehaviour(new Random(0), .25, Actions);
    horde = createHorde();
    surprise = new Surprise(horde.demons(), SurpriseTrackingSpeed);
    // Advertise the clock and the monitored fields to the Zephyr display
    Zephyr.advertise(clock, this);
  }
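
  /**
   * Builds the horde: for every reward function, one on-policy nexting demon
   * per discount in Gammas, plus one off-policy demon per target policy. The
   * reward functions are registered as before-functions so they are updated
   * from the new observation before the demons are.
   */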
  private Horde createHorde() {
    List<RewardFunction> rewardFunctions = createRewardFunctions();
    List<Demon> demons = new ArrayList<Demon>();
    for (RewardFunction rewardFunction : rewardFunctions) {
      for (double gamma : Gammas) {
        demons.add(newNextingPredictionDemon(rewardFunction, gamma));
        for (Policy targetPolicy : TargetPolicies)
          demons.add(newOffPolicyPredictionDemon(rewardFunction, gamma, targetPolicy));
      }
    }
    Horde horde = new Horde();
    horde.demons().addAll(demons);
    for (RewardFunction rewardFunction : rewardFunctions)
      horde.addBeforeFunction((HordeUpdatable) rewardFunction);
    return horde;
  }
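
  /**
   * Off-policy demon: GTD(lambda) learns the value of targetPolicy while the
   * robot follows robotBehaviour. Step sizes are normalized by the state norm
   * to keep learning stable across feature vector sizes.
   */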
  private PredictionOffPolicyDemon newOffPolicyPredictionDemon(RewardFunction rewardFunction, double gamma,
      Policy targetPolicy) {
    GTDLambda gtd = new GTDLambda(Lambda, gamma, .1 / agentState.stateNorm(), 0.0001 / agentState.stateNorm(),
        agentState.stateSize());
    return new PredictionOffPolicyDemon(targetPolicy, robotBehaviour, gtd, rewardFunction);
  }
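
  /**
   * On-policy (nexting) demon: TD(lambda) predicts the discounted sum of the
   * reward function under the robot's own behaviour.
   */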
  private PredictionDemon newNextingPredictionDemon(RewardFunction rewardFunction, double gamma) {
    OnPolicyTD td = new TDLambda(Lambda, gamma, .1 / agentState.stateNorm(), agentState.stateSize());
    return new PredictionDemon(rewardFunction, td);
  }
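
  /**
   * One reward function per predicted sensor: ObsFilter selects the robot's
   * observation labels matching PredictedLabels, and each selected label is
   * wrapped in a RewardObservationFunction that reads the sensor's value as a
   * reward.
   */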
  private List<RewardFunction> createRewardFunctions() {
    ArrayList<RewardFunction> rewardFunctions = new ArrayList<RewardFunction>();
    Legend legend = robot.legend();
    ObsFilter filter = new ObsFilter(legend, PredictedLabels);
    for (String label : filter.legend().getLabels())
      rewardFunctions.add(new RewardObservationFunction(legend, label));
    return rewardFunctions;
  }
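
  /**
   * Main loop: read the latest robot observation, update the agent state and
   * the horde, refresh the surprise measure, and send the next action chosen
   * by the behaviour policy. A surprise spike triggers a happy song.
   */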
  @Override
  public void run() {
    robot.fullMode();
    while (clock.tick()) {
      // Wait for the next raw observation o_{t+1} from the robot
      Observation o_tp1 = robot.waitNewRawObs();
      // Build the new feature vector x_{t+1} from the last action and the new observation
      RealVector x_tp1 = agentState.update(a_t, o_tp1);
      // Update every demon of the horde with the transition (x_t, a_t, x_{t+1})
      horde.update(o_tp1, x_t, a_t, x_tp1);
      // Aggregate the demons' prediction errors into a single surprise measure
      double surpriseMeasure = surprise.updateSurpriseMeasure();
      if (surpriseMeasure > 8.0)
        robot.playSong(IRobotSongs.composeHappySong());
      // Choose and send the next action according to the behaviour policy
      Action a_tp1 = Policies.decide(robotBehaviour, x_tp1);
      robot.sendAction((CreateAction) a_tp1);
      x_t = x_tp1;
      a_t = a_tp1;
    }
  }
}
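
// A minimal sketch for launching the demo outside the Zephyr UI, assuming a
// Create robot is connected; the demo is normally started through Zephyr,
// which drives Runnable instances:
//   new Thread(new CreateSurprise()).start();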