-
Notifications
You must be signed in to change notification settings - Fork 11
/
True Online Sarsa(λ).kt
67 lines (65 loc) · 1.86 KB
/
True Online Sarsa(λ).kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package lab.mars.rl.algo.eligibility_trace.control
import lab.mars.rl.algo.EpisodeListener
import lab.mars.rl.algo.StepListener
import lab.mars.rl.model.MDP
import lab.mars.rl.model.Policy
import lab.mars.rl.model.impl.func.LinearFunc
import lab.mars.rl.model.isNotTerminal
import lab.mars.rl.model.log
import lab.mars.rl.util.log.debug
import lab.mars.rl.util.matrix.Matrix
import lab.mars.rl.util.matrix.MatrixSpec
import lab.mars.rl.util.matrix.minus
import lab.mars.rl.util.matrix.times
/**
 * True Online Sarsa(λ) control with a linear action-value approximator and dutch eligibility traces
 * (cf. Sutton & Barto, *Reinforcement Learning: An Introduction*, 2nd ed., §12.8).
 *
 * Runs [episodes] episodes on this MDP and updates the weight vector of [Qfunc] in place.
 *
 * @param Qfunc linear approximator q̂(s, a, w) = wᵀx(s, a); its feature function `x` and weights `w` are used
 *   directly, and `w` is mutated by this function.
 * @param π policy used to choose every action, including the bootstrapping action A' of each update.
 * @param λ trace-decay parameter.
 * @param α step size.
 * @param episodes number of episodes to run.
 * @param z_maker factory for the eligibility-trace vector; invoked once as `z_maker(d, 1)` with d = `w.size`,
 *   and the resulting trace is zeroed at the start of every episode.
 * @param maxStep per-episode cap on counted steps; the episode is cut off once `step` reaches it.
 * @param episodeListener called after each episode with (episode, step, final state, discounted return G).
 * @param stepListener called after every transition into a non-terminal state with
 *   (episode, step, next state, next action). The final transition into a terminal state is neither
 *   counted in `step` nor reported to this listener.
 */
fun <E> MDP.`True Online Sarsa(λ)`(
    Qfunc: LinearFunc<E>,
    π: Policy,
    λ: Double,
    α: Double,
    episodes: Int,
    z_maker: (Int, Int) -> MatrixSpec = { m, n -> Matrix(m, n) },
    maxStep: Int = Int.MAX_VALUE,
    episodeListener: EpisodeListener = { _, _, _, _ -> },
    stepListener: StepListener = { _, _, _, _ -> }) {
    val X = Qfunc.x
    val w = Qfunc.w
    val d = w.size
    val z = z_maker(d, 1)
    for (episode in 1..episodes) {
        log.debug { "$episode/$episodes" }
        var step = 0
        var s = started()
        var a = π(s)
        var x = X(s, a)
        z.zero()
        var Q_old = 0.0
        var G = 0.0   // discounted return R₁ + γR₂ + γ²R₃ + …
        var γn = 1.0  // γ^k: discount applied to the reward about to be received
        while (true) {
            // Dutch trace: z ← γλz + (1 − αγλ·zᵀx)·x
            z `=` (γ * λ * z + (1.0 - α * γ * λ * (z `T*` x)) * x)
            val (s_next, reward) = a.sample()
            // Accumulate before advancing the discount so the first reward is weighted by γ⁰ = 1.
            G += γn * reward
            γn *= γ
            s = s_next
            val Q = (w `T*` x).toScalar
            // A terminal successor has x' = 0, hence Q' = 0, and no next action is chosen for it.
            val a_next = if (s_next.isNotTerminal) π(s_next) else null
            val `x'` = a_next?.let { X(s_next, it) }
            val `Q'` = `x'`?.let { (w `T*` it).toScalar } ?: 0.0
            val δ = reward - Q + γ * `Q'`
            // True-online weight update; uses the features x of the state-action pair just left.
            w += α * (δ + Q - Q_old) * z - α * (Q - Q_old) * x
            if (a_next == null || `x'` == null) break
            Q_old = `Q'`
            x = `x'`
            a = a_next
            step++
            stepListener(episode, step, s_next, a)
            if (step >= maxStep) break
        }
        episodeListener(episode, step, s, G)
    }
}