/
main.go
69 lines (63 loc) · 1.28 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
package main
import (
"context"
"encoding/gob"
"io"
"log"
"os"
"github.com/owulveryck/lstm"
G "gorgonia.org/gorgonia"
)
// modelFile is where training checkpoints are written.
const modelFile = "tictactoe.bin"

// main trains an LSTM model on the data set returned by newTictactoe and
// checkpoints the model to modelFile with encoding/gob.
//
// The values received from the training info channel are logged every 100
// iterations, and a checkpoint is written every 500 iterations. Once the
// info channel is closed, main reads the training error: io.EOF closes the
// pause channel and is otherwise treated as normal completion, while any
// other non-nil error is fatal. A final checkpoint is then written.
// Checkpoint failures are logged, not fatal.
func main() {
	// NOTE(review): the meaning of these three sizes is defined by
	// lstm.NewModel (presumably input, output and hidden sizes) — confirm.
	model := lstm.NewModel(18, 9, 512)
	learnrate := 1e-3
	l2reg := 1e-3
	clipVal := float64(5)
	solver := G.NewRMSPropSolver(G.WithLearnRate(learnrate), G.WithL2Reg(l2reg), G.WithClip(clipVal))
	tset := newTictactoe()
	pause := make(chan struct{})
	infoChan, errc := model.Train(context.TODO(), tset, solver, pause)

	iter := 1
	for infos := range infoChan {
		if iter%100 == 0 {
			log.Printf("%v: %v", iter, infos)
		}
		if iter%500 == 0 {
			// Save the "software 2.0" checkpoint.
			// NOTE(review): this encodes model while Train may still be
			// running; confirm lstm does not mutate the model concurrently
			// while infoChan is being drained, otherwise this is a data race.
			if err := saveModel(modelFile, model); err != nil {
				log.Println(err)
			}
		}
		iter++
	}

	switch err := <-errc; {
	case err == io.EOF:
		close(pause)
	case err != nil:
		log.Fatal(err)
	}

	// Final "software 2.0" checkpoint after training has finished.
	if err := saveModel(modelFile, model); err != nil {
		log.Println(err)
	}
}

// saveModel gob-encodes v and stores the result at path, replacing any
// previous contents.
//
// The encoding is written to path+".tmp" (opened with O_TRUNC so no stale
// bytes survive) and then renamed over path. A failed or interrupted save
// therefore leaves the previous checkpoint intact, to the extent the
// filesystem's rename is atomic. It returns the first error encountered and
// stops there; a failed open no longer falls through to encoding into a nil
// file.
func saveModel(path string, v interface{}) error {
	tmp := path + ".tmp"
	// 0644: a serialized model is data, not an executable.
	f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	if err := gob.NewEncoder(f).Encode(v); err != nil {
		// Best-effort cleanup; the encode error is the one worth reporting.
		f.Close()
		os.Remove(tmp)
		return err
	}
	if err := f.Close(); err != nil {
		os.Remove(tmp)
		return err
	}
	return os.Rename(tmp, path)
}