forked from sql-machine-learning/sqlflow
/
codegen.go
141 lines (125 loc) · 3.79 KB
/
codegen.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package sql
import (
"text/template"
)
// fieldTypeFeatureType maps a SQL field type name (as produced by
// fieldTypes.get) to the tf.feature_column constructor used for that
// column in the generated Python program.  Only "float" is mapped so
// far; a lookup for any other field type yields "" — callers must
// treat a missing entry as an unsupported column type.
var fieldTypeFeatureType = map[string]string{"float": "numeric_column"}
// columnType describes one column of the SELECT result as consumed by
// the code-generation template: the column's name and the
// tf.feature_column constructor (e.g. "numeric_column") used to
// represent it.  Field names are referenced by the template via
// {{.Name}} / {{.Type}} and must not be renamed.
type columnType struct {
	Name string
	Type string
}
// connectionConfig carries the MySQL connection parameters and the
// working directory that the generated Python program uses to connect
// to the database and to save/load trained models.  It is embedded in
// TemplateFiller so the template can reference its fields directly.
type connectionConfig struct {
	User     string
	Password string
	Host     string
	Database string
	WorkDir  string
}
// TemplateFiller is the data object fed to the code-generation
// template.  It aggregates everything needed to emit a runnable
// TensorFlow program: whether to train or infer, the model
// configuration parsed from the extended SELECT, the input/label
// columns, and the database connection settings.  Field names are
// referenced by the template and must not be renamed.
type TemplateFiller struct {
	// Train selects the training branch of the template when true,
	// the inference branch otherwise.
	Train bool
	// Model Config
	StandardSelect string            // the standard SQL SELECT to fetch data
	Estimator      string            // tf.estimator class name, e.g. "DNNClassifier"
	Attrs          map[string]string // TRAIN clause attributes, e.g. hidden_units
	Save           string            // model directory name under WorkDir
	// Data Config
	X []columnType // feature columns
	Y columnType   // label column
	// Connection Config
	connectionConfig
}
// NewTemplateFiller builds a TemplateFiller from a parsed extended
// SELECT statement, the resolved field types, and the database
// connection configuration.  It returns (nil, false) when the type of
// any referenced column — or of the label — is unknown, or when the
// field type has no tf.feature_column mapping in fieldTypeFeatureType.
// The latter check prevents emitting an empty constructor name
// (tf.feature_column.(...)), which would be invalid Python.
func NewTemplateFiller(pr *extendedSelect, fts fieldTypes, cfg connectionConfig) (*TemplateFiller, bool) {
	r := &TemplateFiller{
		Train:          pr.train,
		StandardSelect: pr.standardSelect.String(),
		Estimator:      pr.estimator,
		Attrs:          make(map[string]string),
		Save:           pr.save,
	}
	for k, v := range pr.attrs {
		r.Attrs[k] = v.String()
	}
	// featureType resolves a column name to its feature-column
	// constructor, failing when the field type is unknown or unmapped.
	featureType := func(name string) (string, bool) {
		typ, ok := fts.get(name)
		if !ok {
			return "", false
		}
		ft, ok := fieldTypeFeatureType[typ]
		return ft, ok
	}
	r.X = make([]columnType, 0, len(pr.columns))
	for _, c := range pr.columns {
		ft, ok := featureType(c.val)
		if !ok {
			return nil, false
		}
		r.X = append(r.X, columnType{Name: c.val, Type: ft})
	}
	ft, ok := featureType(pr.label)
	if !ok {
		return nil, false
	}
	r.Y = columnType{Name: pr.label, Type: ft}
	r.connectionConfig = cfg
	return r, true
}
// codegen_template_text is the text/template source of the Python
// program generated for an extended SELECT.  It is executed with a
// *TemplateFiller: the program connects to MySQL, runs the standard
// SELECT, builds feature columns, and then either trains the
// configured estimator ({{if .Train}}) or evaluates a saved model.
const codegen_template_text = `
import tensorflow as tf
import sys, json, os
import mysql.connector
` +
	// TODO(tonyyang-svail): remove hard coded BATCHSIZE, STEP
	`
BATCHSIZE = 1
STEP = 1000
WORK_DIR = "{{.WorkDir}}"
USER = "{{.User}}"
PASSWORD = "{{.Password}}"
HOST = "{{.Host}}"
DATABASE = "{{.Database}}"
db = mysql.connector.connect(user=USER, passwd=PASSWORD, host=HOST, database=DATABASE)
cursor = db.cursor()
cursor.execute("""{{.StandardSelect}}""")
field_names = [i[0] for i in cursor.description]
columns = map(list, zip(*cursor.fetchall()))
feature_columns = [{{range .X}}tf.feature_column.{{.Type}}(key="{{.Name}}"),
{{end}}]
feature_column_names = [{{range .X}}"{{.Name}}",
{{end}}]
X = {name: columns[field_names.index(name)] for name in feature_column_names}
Y = columns[field_names.index("{{.Y.Name}}")]
{{if .Train}}
classifier = tf.estimator.{{.Estimator}}(
feature_columns=feature_columns,
hidden_units={{index .Attrs "hidden_units"}},
n_classes={{index .Attrs "n_classes"}},
model_dir=os.path.join(WORK_DIR, "{{.Save}}"))
def train_input_fn(features, labels, batch_size):
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
return dataset
classifier.train(
input_fn=lambda:train_input_fn(X, Y, BATCHSIZE),
steps=STEP)
` +
	// TODO(tonyyang-svail): avoid JSON
	// print("Dumping sql parsed data ...")
	// with open(os.path.join(WORK_DIR, "{{.Save}}", SQL_PARSING_RESULT_FILE), "w") as f:
	//     f.write("""{{.JSON}}""")
	`
print("Done training")
{{- else}}
` +
	// TODO(tonyyang-svail): avoid JSON
	// with open(os.path.join(WORK_DIR, "{{.InferClause.Model}}", SQL_PARSING_RESULT_FILE)) as f:
	//     desc = json.load(f)
	//
	// NOTE(review): with the load above commented out, the generated
	// inference program references an undefined Python variable `desc`
	// below, so the INFER branch cannot run as-is — confirm intended.
	`
def eval_input_fn(features, labels, batch_size):
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
dataset = dataset.batch(batch_size)
return dataset
` +
	// TODO(tonyyang-svail): remove hard coded DNNClassifier
	//
	// NOTE(review): TemplateFiller declares no InferClause field, so
	// executing the {{.InferClause.Model}} references in this branch
	// should fail at template-execution time — verify against the
	// parser's AST / the intended filler type.
	`
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=eval(desc["TrainClause"]["Attrs"]["hidden_units"]),
n_classes=eval(desc["TrainClause"]["Attrs"]["n_classes"]),
model_dir=os.path.join(WORK_DIR, "{{.InferClause.Model}}"))
eval_result = classifier.evaluate(
input_fn=lambda:eval_input_fn(X, Y, BATCHSIZE),
steps=STEP)
print("\nTest set accuracy: {accuracy:0.5f}\n".format(**eval_result))
{{- end}}
`
// codegen_template is the parsed code-generation template, compiled
// once at package initialization.  template.Must panics on a template
// syntax error, surfacing it at program start instead of per query.
// The explicit *template.Template type was redundant given the
// initializer and has been dropped per Go convention.
var codegen_template = template.Must(template.New("codegen").Parse(codegen_template_text))