/
main.go
129 lines (103 loc) · 2.62 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"runtime/pprof"
. "gorgonia.org/gorgonia"
"gorgonia.org/tensor"
)
// Profiling flags: when set, main writes pprof-format profiles to the
// given file paths (CPU profile during the run, heap profile at exit).
var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")
var memprofile = flag.String("memprofile", "", "write memory profile to this file")
// TODO/Questions:
// Input:
// extend vocab.go to import txt
// - got ur dict got ur vocab got ur vocabindex
// - replace sentencesRaw
// - slice into chars. map to []rune
// - produce paired input/output examples (input, input + 1 across corpus)
// - 1HV the lot? sparsity meh?
// Model:
// - why keySize = len of keys +2?
// - definition of ValueGrad?
// Main:
// - keys/durations/mOut?
// - why pointer for trainiter?
// Model and training hyperparameters.
const (
	embeddingSize = 20 // width of the embedding passed to NewS2S — assumes per-character embedding; TODO confirm against the model definition
	maxOut        = 30 // presumably the maximum sampled output length; not referenced in this file — verify in the model/predict code
	// gradient update stuff
	l2reg     = 0.000001 // L2 regularization strength handed to NewRMSPropSolver
	learnrate = 0.01     // RMSProp learning rate handed to NewRMSPropSolver
	clipVal   = 5.0      // gradient clipping value handed to NewRMSPropSolver
)
// Command-line knobs for training.
// NOTE(review): neither flag is consulted in this file's visible code —
// main passes a literal 300 to train, and init reads a literal filename.
// Confirm whether they are read elsewhere in the package.
var trainiter = flag.Int("iter", 5, "How many iterations to train")
var inputFile = flag.String("filename", "shakespeare.txt", "Filename of text to train on")

// various global variable inits
// -1 marks "not yet computed"; presumably filled in during corpus/vocab
// setup elsewhere in the package — verify.
var epochSize = -1
var inputSize = -1
var outputSize = -1

// const corpus string = "shakespeare.txt"
// const corpus string = `the cat sat on the mat
// hello world
// wild stalyns

// corpus holds the raw training text, loaded from disk by init().
var corpus string
// init loads the entire training corpus into the package-level corpus
// variable, panicking (with context) if the file cannot be read, since
// nothing else in this program can run without it.
//
// NOTE(review): the filename is hard-coded here rather than taken from
// the -filename flag; flags are not yet parsed when init runs, so
// honoring the flag would require moving this load into main after
// flag.Parse.
func init() {
	const corpusFile = "shakespeare.txt"
	buf, err := ioutil.ReadFile(corpusFile)
	if err != nil {
		// Wrap instead of a bare panic(err) so the failure names the file.
		panic(fmt.Errorf("reading training corpus %q: %w", corpusFile, err))
	}
	corpus = string(buf)
}
var dt tensor.Dtype = tensor.Float32
// type trainingPair string
// type pair struct {
// t string
// tplusone string
// }
// type trainingPair struct {
// in, out []message
// }
// TODO: decide where the inputs should be one-hot encoded (1HV).
// func OneHotVector(id, classes int, t tensor.Dtype, opts ...NodeConsOpt) *Node {
// main wires the whole example together: optional pprof profiling, model
// and solver construction, a training run, a sampling pass over the
// corpus, and heatmap snapshots of the decoder embedding before/after
// training.
func main() {
	flag.Parse()

	// Optional CPU profiling, enabled with -cpuprofile=<path>.
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()
		// Previously the error return of StartCPUProfile was ignored.
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		defer pprof.StopCPUProfile()
	}

	// f, err := os.Create("trace.out")
	// if err != nil {
	// panic(err)
	// }
	// defer f.Close()
	// err = trace.Start(f)
	// if err != nil {
	// panic(err)
	// }
	// defer trace.Stop()

	hiddenSize := 100

	// vocab, vocabIndex, and sentences are package-level state built from
	// the corpus elsewhere in this package — TODO confirm (not visible here).
	s2s := NewS2S(hiddenSize, embeddingSize, vocab)
	solver := NewRMSPropSolver(WithLearnRate(learnrate), WithL2Reg(l2reg), WithClip(clipVal), WithBatchSize(float64(len(sentences))))

	for k, v := range vocabIndex {
		log.Printf("%q %v", k, v)
	}

	// Snapshot the decoder embedding before training. The error was
	// previously unchecked, so a failed Heatmap would panic inside Save.
	p, h, w, err := Heatmap(s2s.decoder.Value().(*tensor.Dense))
	if err != nil {
		log.Fatal(err)
	}
	p.Save(w, h, "embn0.png") // best-effort: Save's return is intentionally ignored

	// NOTE(review): 300 is hard-coded even though the -iter flag exists;
	// left as-is to preserve behavior — confirm which is intended.
	if err := train(s2s, 300, solver, sentences); err != nil {
		panic(err)
	}

	out, err := s2s.predict([]rune(corpus))
	if err != nil {
		panic(err)
	}
	fmt.Printf("OUT %q\n", out)

	// Snapshot the decoder embedding after training (same error fix).
	p, h, w, err = Heatmap(s2s.decoder.Value().(*tensor.Dense))
	if err != nil {
		log.Fatal(err)
	}
	p.Save(w, h, "embn.png") // best-effort, as above

	// Optional heap profile, enabled with -memprofile=<path>. The flag
	// was previously declared but never consulted.
	if *memprofile != "" {
		f, err := os.Create(*memprofile)
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.WriteHeapProfile(f); err != nil {
			log.Fatal(err)
		}
		f.Close()
	}
}