-
Notifications
You must be signed in to change notification settings - Fork 11
/
PolicyIteration.kt
88 lines (80 loc) · 2.47 KB
/
PolicyIteration.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package lab.mars.rl.algo.dp
import lab.mars.rl.algo.Q_from_V
import lab.mars.rl.algo.V_from_Q
import lab.mars.rl.model.impl.mdp.IndexedMDP
import lab.mars.rl.model.impl.mdp.IndexedPolicy
import lab.mars.rl.model.impl.mdp.OptimalSolution
import lab.mars.rl.model.isNotTerminal
import lab.mars.rl.model.log
import lab.mars.rl.util.collection.filter
import lab.mars.rl.util.collection.fork
import lab.mars.rl.util.log.debug
import lab.mars.rl.util.math.argmax
import lab.mars.rl.util.math.Σ
import lab.mars.rl.util.tuples.tuple3
import org.apache.commons.math3.util.FastMath.abs
import org.apache.commons.math3.util.FastMath.max
/**
* <p>
* Created on 2017-09-05.
* </p>
*
* @author wumo
*/
/** Convergence threshold for iterative policy evaluation: sweeps stop once the largest per-state update Δ falls below this value. `const` makes it a compile-time constant inlined at call sites. */
const val θ = 1e-6
/**
 * Policy iteration on state values V(s).
 *
 * Alternates two phases until the policy no longer changes:
 * 1. Policy Evaluation — iteratively sweeps the non-terminal states, backing up
 *    V(s) under the current policy π until the largest change Δ drops below [θ];
 * 2. Policy Improvement — replaces π(s) with the action that is greedy w.r.t. V.
 *
 * @return the optimal (π, V, Q) triple; Q is derived from the converged V at the end.
 */
fun IndexedMDP.`Policy Iteration V`(): OptimalSolution {
  val V = VFunc { 0.0 }
  val π = IndexedPolicy(QFunc { 1.0 })
  val Q = QFunc { 0.0 }
  while (true) {
    // Policy Evaluation: repeat full sweeps until the value function stabilizes.
    do {
      var Δ = 0.0
      for (s in states.filter { it.isNotTerminal }) {
        val previous = V[s]
        // Expected one-step return under the current policy's action distribution.
        V[s] = Σ(π(s).possibles) { probability * (reward + γ * V[next]) }
        Δ = max(Δ, abs(previous - V[s]))
      }
      log.debug { "Δ=$Δ" }
    } while (Δ >= θ)  // Kotlin's do-while condition can read Δ declared in the body.
    // Policy Improvement: make π greedy with respect to the evaluated V.
    var stable = true
    for (s in states.filter { it.isNotTerminal }) {
      val before = π(s)
      val greedy = argmax(s.actions) { Σ(possibles) { probability * (reward + γ * V[next]) } }
      π[s] = greedy
      // Referential comparison: argmax yields an element of s.actions, the same objects π stores.
      if (before !== greedy) stable = false
    }
    if (stable) break
  }
  val result = tuple3(π, V, Q)
  Q_from_V(γ, states, result)
  return result
}
/**
 * Policy iteration on action values Q(s, a).
 *
 * Alternates two phases until the policy no longer changes:
 * 1. Policy Evaluation — sweeps every (state, action) pair, backing up Q(s, a)
 *    under the current policy π until the largest change Δ drops below [θ];
 * 2. Policy Improvement — replaces π(s) with the action maximizing Q(s, ·).
 *
 * @return the optimal (π, V, Q) triple; V is derived from the converged Q at the end.
 */
fun IndexedMDP.`Policy Iteration Q`(): OptimalSolution {
  val V = VFunc { 0.0 }
  val π = IndexedPolicy(QFunc { 1.0 })
  val Q = QFunc { 0.0 }
  while (true) {
    // Policy Evaluation: repeat full (s, a) sweeps until the action values stabilize.
    do {
      var Δ = 0.0
      for ((s, a) in states.fork { it.actions }) {
        val previous = Q[s, a]
        // One-step backup; terminal successors (no actions) contribute zero future value.
        Q[s, a] = Σ(a.possibles) { probability * (reward + γ * if (next.actions.any()) Q[next, π(next)] else 0.0) }
        Δ = max(Δ, abs(previous - Q[s, a]))
      }
      log.debug { "Δ=$Δ" }
    } while (Δ >= θ)  // Kotlin's do-while condition can read Δ declared in the body.
    // Policy Improvement: make π greedy with respect to the evaluated Q.
    var stable = true
    for (s in states.filter { it.isNotTerminal }) {
      val before = π(s)
      val greedy = argmax(s.actions) { Q[s, it] }
      π[s] = greedy
      // Referential comparison: argmax yields an element of s.actions, the same objects π stores.
      if (before !== greedy) stable = false
    }
    if (stable) break
  }
  val result = tuple3(π, V, Q)
  V_from_Q(states, result)
  return result
}