-
Notifications
You must be signed in to change notification settings - Fork 0
/
level.go
36 lines (32 loc) · 800 Bytes
/
level.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package nrpa
import (
"alda/utils"
)
// Level is the data structure kept for each NRPA recursive call (level).
type Level struct {
	// Policy is the level's weight table. AdaptPolicy indexes it as
	// Policy[u][k], where u is the previously chosen move (0 at the first
	// step) and k is a candidate move.
	Policy [][]float64
	// BestRollout is the rollout whose moves AdaptPolicy reinforces; its
	// Tour is read at index step+1 for every step.
	BestRollout *Rollout
	// LegalMovesPerStep[step] lists the moves AdaptPolicy normalizes over
	// at that step.
	LegalMovesPerStep [][]int
}
// AdaptPolicy nudges the level policy toward the moves of the current
// BestRollout.
//
// policyTmp is caller-supplied scratch space. It is handed to
// utils.CopyPolicy together with l.Policy before any weight changes;
// presumably that call snapshots l.Policy into policyTmp (confirm against
// utils). All penalty terms are then read from policyTmp, so they are based
// on the weights as they were before this call.
//
// For every step in l.LegalMovesPerStep, the weight of the move taken by the
// best rollout (Tour[step+1]) is raised by alpha, and every legal move k of
// that step is lowered by alpha * Exp(policyTmp[u][k]) / z, where z is the
// sum of Exp(policyTmp[u][k]) over the step's legal moves and u is the move
// chosen at the previous step (0 for the first step).
//
// BestRollout is only dereferenced inside the loop, and its Tour must hold
// at least len(l.LegalMovesPerStep)+1 entries.
func (l *Level) AdaptPolicy(policyTmp [][]float64) {
	best := l.BestRollout

	// Hand the current weights to the scratch table before mutating them.
	utils.CopyPolicy(l.Policy, policyTmp)

	from := 0
	for step, legal := range l.LegalMovesPerStep {
		chosen := best.Tour[step+1]

		// Reward the move the best rollout actually played.
		l.Policy[from][chosen] += alpha

		// Normalizer over this step's legal moves, taken from the snapshot.
		var norm float64
		for _, move := range legal {
			norm += utils.Exp(policyTmp[from][move])
		}

		// Penalize each legal move in proportion to its snapshot share.
		for _, move := range legal {
			l.Policy[from][move] -= alpha * utils.Exp(policyTmp[from][move]) / norm
		}

		// The move just played becomes the row index for the next step.
		from = chosen
	}
}