forked from schemalex/schemalex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
source.go
245 lines (212 loc) · 6.97 KB
/
source.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package schemalex
import (
"bytes"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
"io/ioutil"
"net/url"
"os"
"os/exec"
"strconv"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/schemalex/schemalex/internal/errors"
)
// SchemaSource is implemented by anything that can supply a MySQL
// database schema for schemalex to work with.
type SchemaSource interface {
	// WriteSchema does whatever is needed to obtain the database schema
	// and writes it to dst.
	WriteSchema(dst io.Writer) error
}
type (
	// readerSource streams the schema from an arbitrary io.Reader.
	readerSource struct {
		src io.Reader
	}

	// mysqlSource is a MySQL DSN as accepted by mysql.ParseDSN.
	mysqlSource string

	// localFileSource is the path of a local file containing the schema.
	localFileSource string

	// localGitSource names a file, relative to a local git repository
	// directory, as it existed at a particular commit.
	localGitSource struct {
		dir       string
		file      string
		commitish string
	}
)
// NewSchemaSource creates a SchemaSource based on the given URI.
//
// Supported forms:
//   - "-" reads the schema from stdin.
//   - "mysql://DSN" connects to MySQL. Everything after "mysql://" is passed
//     verbatim to NewMySQLSource; the scheme is matched case-insensitively.
//   - "local-git:///path/to/dir?file=foo&commitish=bar" reads a file from a
//     local git repository at the given commit.
//   - "file:///path/to/file" reads a local file. Only an empty host or
//     "localhost" is accepted.
//
// A string that does not match any of the above patterns and has no scheme
// part is treated as a local file path.
func NewSchemaSource(uri string) (SchemaSource, error) {
	// "-" is a special source, denoting stdin.
	if uri == "-" {
		return NewReaderSource(os.Stdin), nil
	}

	// MySQL must be handled before url.Parse. DSNs such as
	// "user:pass@tcp(127.0.0.1:3306)/db" are not valid URLs: net/url rejects
	// "3306)" as a port, so parsing first would make such sources unusable.
	const mysqlPrefix = "mysql://"
	if len(uri) >= len(mysqlPrefix) && strings.EqualFold(uri[:len(mysqlPrefix)], mysqlPrefix) {
		return NewMySQLSource(uri[len(mysqlPrefix):]), nil
	}

	u, err := url.Parse(uri)
	if err != nil {
		return nil, errors.Wrap(err, `failed to parse uri`)
	}

	switch strings.ToLower(u.Scheme) {
	case "local-git":
		// local-git:///path/to/dir?file=foo&commitish=bar
		q := u.Query()
		return NewLocalGitSource(u.Path, q.Get("file"), q.Get("commitish")), nil
	case "mysql":
		// Only reachable when "mysql:" is not followed by "//". There is no
		// reliable way to locate the DSN, so reject the input. Slicing a fixed
		// offset here would panic on short inputs such as "mysql:".
		return nil, errors.New(`mysql sources must be of the form mysql://DSN`)
	case "file", "":
		// Eh, no remote host, please.
		// err is nil at this point, so build a fresh error instead of wrapping it.
		if u.Host != "" && u.Host != "localhost" {
			return nil, errors.New(`remote hosts for file:// sources are not supported`)
		}
		return NewLocalFileSource(u.Path), nil
	}
	return nil, errors.New("invalid source")
}
// NewReaderSource returns a SchemaSource that copies whatever src yields
// as the schema.
func NewReaderSource(src io.Reader) SchemaSource {
	rs := new(readerSource)
	rs.src = src
	return rs
}
// NewMySQLSource returns a SchemaSource that obtains its contents by
// connecting to the MySQL instance described by the DSN s.
//
// If the DSN's "tls" parameter is set to a boolean true value, the extra
// parameters "ssl-ca", "ssl-cert", and "ssl-secret" are used to build and
// register a TLS configuration automatically. Each of these must name a
// local file.
//
// The "tls" parameter MUST BE A BOOLEAN for that to happen. Any other value
// is assumed to be the name of a TLS configuration that you have already
// registered yourself.
func NewMySQLSource(s string) SchemaSource {
	dsn := mysqlSource(s)
	return dsn
}
// NewLocalFileSource returns a SchemaSource that reads the schema from the
// local file at path s.
func NewLocalFileSource(s string) SchemaSource {
	path := localFileSource(s)
	return path
}
// NewLocalGitSource returns a SchemaSource that reads the schema from file,
// as it existed at commitish, in the git repository located at gitDir.
func NewLocalGitSource(gitDir, file, commitish string) SchemaSource {
	src := new(localGitSource)
	src.dir, src.file, src.commitish = gitDir, file, commitish
	return src
}
// WriteSchema copies the wrapped reader into dst until EOF.
func (s *readerSource) WriteSchema(dst io.Writer) error {
	_, err := io.Copy(dst, s.src)
	if err == nil {
		return nil
	}
	return errors.Wrap(err, `failed to write schema to dst`)
}
// MySQLConfig creates a *mysql.Config struct from the given DSN.
//
// If the DSN's "tls" parameter parses as a boolean true value, the
// "ssl-ca", "ssl-cert" and "ssl-secret" parameters must all be present.
// They name the CA bundle, the client certificate and the client key files.
// These files are used to build a tls.Config, which is registered with the
// mysql driver under a freshly generated name. cfg.TLSConfig is then set to
// that name, and the three ssl-* parameters are removed from cfg.Params.
// Any other "tls" value is left untouched.
func (s mysqlSource) MySQLConfig() (*mysql.Config, error) {
	cfg, err := mysql.ParseDSN(string(s))
	if err != nil {
		return nil, errors.Wrap(err, `failed to parse DSN`)
	}

	// because _I_ need support for tls, I'm going to handle setting up
	// the tls stuff, by using
	// tls=true&ssl-ca=...&ssl-cert=...&ssl-secret=...
	if enabled, perr := strconv.ParseBool(cfg.TLSConfig); perr != nil || !enabled {
		return cfg, nil
	}

	sslCa := cfg.Params["ssl-ca"]
	sslCert := cfg.Params["ssl-cert"]
	sslSecret := cfg.Params["ssl-secret"]
	if sslCa == "" || sslCert == "" || sslSecret == "" {
		return nil, errors.New(`to enable tls, you must provide ssl-ca, ssl-cert, and ssl-secret parameters to the DSN`)
	}

	tlsName, err := registerMySQLTLSConfig(sslCa, sslCert, sslSecret)
	if err != nil {
		return nil, err
	}

	// go-sql-driver/mysql treats every unrecognized DSN parameter as a
	// system variable and issues "SET name=value" on connect. These keys
	// are only meaningful to us, and "SET ssl-ca=..." is a syntax error
	// that would make every connection fail, so drop them.
	delete(cfg.Params, "ssl-ca")
	delete(cfg.Params, "ssl-cert")
	delete(cfg.Params, "ssl-secret")

	cfg.TLSConfig = tlsName
	return cfg, nil
}

// registerMySQLTLSConfig builds a tls.Config from the given CA bundle,
// client certificate and client key files. It registers the config with the
// mysql driver under a newly generated unique name and returns that name.
func registerMySQLTLSConfig(caFile, certFile, keyFile string) (string, error) {
	// When comparing two mysql schemas against each other, we will have
	// multiple calls to RegisterTLSConfig, and in that case we need
	// unique names for both.
	//
	// Here, we do the poor man's UUID (version-4 bit layout) to create a
	// unique name.
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return "", errors.Wrap(err, `failed to generate random tls config name`)
	}
	b[6] = (b[6] & 0x0F) | 0x40
	b[8] = (b[8] &^ 0x40) | 0x80
	tlsName := fmt.Sprintf("custom-tls-%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])

	rootCertPool := x509.NewCertPool()
	pem, err := ioutil.ReadFile(caFile)
	if err != nil {
		return "", errors.Wrap(err, `failed to read ssl-ca file`)
	}
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		return "", errors.New(`failed to append ssl-ca PEM to cert pool`)
	}

	certs, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return "", errors.Wrap(err, `failed to load X509 key pair`)
	}

	if err := mysql.RegisterTLSConfig(tlsName, &tls.Config{
		RootCAs:      rootCertPool,
		Certificates: []tls.Certificate{certs},
	}); err != nil {
		return "", errors.Wrap(err, `failed to register tls config`)
	}
	return tlsName, nil
}
// open builds the driver configuration for this source and returns a
// *sql.DB handle for it. Like sql.Open, it does not itself establish a
// connection.
func (s mysqlSource) open() (*sql.DB, error) {
	cfg, err := s.MySQLConfig()
	if err == nil {
		return sql.Open("mysql", cfg.FormatDSN())
	}
	return nil, errors.Wrap(err, `failed to create MySQL config from source spec`)
}
// WriteSchema copies the contents of the local file into dst.
func (s localFileSource) WriteSchema(dst io.Writer) error {
	file, openErr := os.Open(string(s))
	if openErr != nil {
		return errors.Wrapf(openErr, `failed to open local file %s`, s)
	}
	defer file.Close()

	if _, copyErr := io.Copy(dst, file); copyErr != nil {
		return errors.Wrap(copyErr, `failed to copy file contents to dst`)
	}
	return nil
}
// WriteSchema connects to the database and lists its tables with
// SHOW TABLES. It writes the SHOW CREATE TABLE output for each table to dst.
// Statements are terminated by ';' and separated by a blank line.
func (s mysqlSource) WriteSchema(dst io.Writer) error {
	db, err := s.open()
	if err != nil {
		return errors.Wrap(err, `failed to open connection to database`)
	}
	defer db.Close()

	tableRows, err := db.Query("SHOW TABLES")
	if err != nil {
		return errors.Wrap(err, `failed to execute 'SHOW TABLES'`)
	}
	defer tableRows.Close()

	var table string
	var tableSchema string
	var buf bytes.Buffer
	for tableRows.Next() {
		if err = tableRows.Scan(&table); err != nil {
			return errors.Wrap(err, `failed to scan tables`)
		}
		// Inside a backtick-quoted identifier a literal backtick must be
		// doubled; otherwise such a table name would break the statement.
		quoted := "`" + strings.Replace(table, "`", "``", -1) + "`"
		// NOTE(review): SHOW TABLES also lists views, for which SHOW CREATE
		// TABLE returns four columns and this two-column Scan fails.
		if err = db.QueryRow("SHOW CREATE TABLE "+quoted).Scan(&table, &tableSchema); err != nil {
			return errors.Wrapf(err, `failed to execute 'SHOW CREATE TABLE "%s"'`, table)
		}

		if buf.Len() > 0 {
			buf.WriteString("\n\n")
		}
		// TODO remove dynamic info. ex) AUTO_INCREMENT,PARTITION
		buf.WriteString(tableSchema)
		buf.WriteByte(';')
	}
	// Next returns false on error as well as on exhaustion. Without this
	// check, a mid-iteration failure would produce a silently truncated schema.
	if err = tableRows.Err(); err != nil {
		return errors.Wrap(err, `failed to iterate over 'SHOW TABLES' results`)
	}

	return NewReaderSource(&buf).WriteSchema(dst)
}
// WriteSchema runs `git show <commitish>:<file>` with s.dir as the working
// directory and copies the command's standard output into dst. If the
// command fails, whatever git wrote to stderr is included in the returned
// error so the cause (bad revision, missing path, ...) is visible.
func (s localGitSource) WriteSchema(dst io.Writer) error {
	var out, stderr bytes.Buffer
	cmd := exec.Command("git", "show", fmt.Sprintf("%s:%s", s.commitish, s.file))
	cmd.Stdout = &out
	cmd.Stderr = &stderr
	cmd.Dir = s.dir
	if err := cmd.Run(); err != nil {
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return errors.Wrapf(err, `failed to run git command: %s (stderr: %s)`, cmd.Args, msg)
		}
		return errors.Wrapf(err, `failed to run git command: %s`, cmd.Args)
	}
	return NewReaderSource(&out).WriteSchema(dst)
}