-
Notifications
You must be signed in to change notification settings - Fork 402
/
data.go
201 lines (169 loc) · 4.79 KB
/
data.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package dbschema
import (
"context"
"fmt"
"regexp"
"sort"
"strings"
"github.com/zeebo/errs"
)
// The types below hold a snapshot of database contents in which every value
// has already been rendered as a string.
type (
	// Data is the database content formatted as strings: one TableData per
	// table, kept in the order the tables were added.
	Data struct {
		Tables []*TableData
	}

	// TableData is the content of a single SQL table. Columns lists the
	// column names in select order; AddRow only accepts rows whose entries
	// follow that same order.
	TableData struct {
		Name    string
		Columns []string
		Rows    []RowData
	}

	// ColumnData is the value of one column within a row.
	ColumnData struct {
		Column string
		Value  string
	}

	// RowData is the content of a single row, one ColumnData per column.
	RowData []ColumnData
)

// String renders the column as "column:value".
func (c ColumnData) String() string {
	// Every operand is a string, so Sprint inserts no separating spaces.
	return fmt.Sprint(c.Column, ":", c.Value)
}

// Less reports whether row sorts before b. Rows are compared position by
// position using only the values (column names are ignored), with Go's
// bytewise string ordering. When one row is a prefix of the other, the
// shorter row sorts first.
func (row RowData) Less(b RowData) bool {
	for k := 0; k < len(row) && k < len(b); k++ {
		switch x, y := row[k].Value, b[k].Value; {
		case x < y:
			return true
		case x > y:
			return false
		}
	}
	return len(row) < len(b)
}

// AddTable appends table to data.Tables. It does not check whether a table
// with the same name is already present.
func (data *Data) AddTable(table *TableData) {
	data.Tables = append(data.Tables, table)
}
// AddRow appends row to the table after checking that it has exactly one
// entry per table column and that the entry names match table.Columns in
// order. A mismatched row is rejected with an error and the table is left
// unchanged.
func (table *TableData) AddRow(row RowData) error {
	consistent := len(row) == len(table.Columns)
	for i := 0; consistent && i < len(row); i++ {
		consistent = row[i].Column == table.Columns[i]
	}
	if !consistent {
		return errs.New("inconsistent row added to table")
	}

	table.Rows = append(table.Rows, row)
	return nil
}
// FindTable returns the first table whose Name equals tableName, together
// with true. If no such table exists, it returns nil and false.
func (data *Data) FindTable(tableName string) (*TableData, bool) {
	for i := range data.Tables {
		if candidate := data.Tables[i]; candidate.Name == tableName {
			return candidate, true
		}
	}
	return nil, false
}
// Sort orders the rows of every table in data. The order of the tables
// themselves is left untouched.
func (data *Data) Sort() {
	for i := range data.Tables {
		data.Tables[i].Sort()
	}
}
// Sort orders the table's rows in place using RowData.Less.
func (table *TableData) Sort() {
	rows := table.Rows
	less := func(i, k int) bool {
		return rows[i].Less(rows[k])
	}
	sort.Slice(rows, less)
}
// Clone returns a copy of row backed by a new array, so later writes to the
// original row's elements do not affect the clone. The result is never nil,
// even when row is nil.
func (row RowData) Clone() RowData {
	duplicate := make(RowData, len(row))
	copy(duplicate, row)
	return duplicate
}
// QueryData loads the full contents of every table listed in schema.
//
// Both the table name and each column name are concatenated into a SELECT
// statement. For that reason, each is checked first with ValidateTableName
// and ValidateColumnName. Each column name is then passed through
// quoteColumn to build the select list, and every selected value is scanned
// into a string. After all tables are loaded, their rows are sorted with
// Data.Sort.
//
// On any error, QueryData returns nil and that error.
func QueryData(ctx context.Context, db Queryer, schema *Schema, quoteColumn func(string) string) (*Data, error) {
	result := &Data{}

	for _, tableSchema := range schema.Tables {
		if err := ValidateTableName(tableSchema.Name); err != nil {
			return nil, err
		}

		names := tableSchema.ColumnNames()
		selectList := make([]string, len(names))
		for i := range names {
			if err := ValidateColumnName(names[i]); err != nil {
				return nil, err
			}
			selectList[i] = quoteColumn(names[i])
		}

		table := &TableData{Name: tableSchema.Name, Columns: names}
		result.AddTable(table)

		/* #nosec G202 */ // The columns names and table name are validated above
		query := `SELECT ` + strings.Join(selectList, ", ") + ` FROM ` + table.Name

		if err := queryTableRows(ctx, db, query, table); err != nil {
			return nil, err
		}
	}

	result.Sort()
	return result, nil
}

// queryTableRows executes query and appends each returned row to table,
// scanning the values in the order of table.Columns. Any error from closing
// the result set is combined with the returned error.
func queryTableRows(ctx context.Context, db Queryer, query string, table *TableData) (err error) {
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return err
	}
	defer func() { err = errs.Combine(err, rows.Close()) }()

	// One scan buffer is reused for every row. AddRow receives a clone, so
	// stored rows never alias the buffer.
	scratch := make(RowData, len(table.Columns))
	targets := make([]interface{}, len(scratch))
	for i := range scratch {
		scratch[i].Column = table.Columns[i]
		targets[i] = &scratch[i].Value
	}

	for rows.Next() {
		if err := rows.Scan(targets...); err != nil {
			return err
		}
		if err := table.AddRow(scratch.Clone()); err != nil {
			return err
		}
	}
	return rows.Err()
}
// columnNameWhiteList matches names made of one or more dash-separated
// segments. Each segment is a non-empty run of ASCII letters, digits or
// underscores, so a dash can never lead, trail, or repeat.
//
// The earlier pattern ^(?:[a-zA-Z0-9_](?:-[a-zA-Z0-9_]|[a-zA-Z0-9_])?)+$
// wrongly rejected names with a one-character segment between two dashes,
// such as "a-b-c" or "my-a-table".
var columnNameWhiteList = regexp.MustCompile(`^[a-zA-Z0-9_]+(?:-[a-zA-Z0-9_]+)*$`)
// ValidateColumnName returns an error unless column matches
// columnNameWhiteList. A valid name has at least one character and contains
// only lower and upper case ASCII letters, digits, underscores and dashes.
// A dash may not appear at the beginning or the end of the name, nor twice
// in a row.
//
// QueryData relies on this check because it concatenates column names into
// SQL.
func ValidateColumnName(column string) error {
	if columnNameWhiteList.MatchString(column) {
		return nil
	}
	return errs.New(
		"forbidden column name, it can only contains letters, numbers, underscores and dashes not in a row. Got: %s",
		column,
	)
}
// tableNameWhiteList matches a table name, optionally qualified with a
// schema. The schema and the table are separated by exactly one dot (e.g.
// public.my_table). Each of the two parts follows the same rule as
// columnNameWhiteList: one or more dash-separated, non-empty runs of ASCII
// letters, digits or underscores.
//
// The earlier pattern rejected valid parts containing a one-character
// segment between two dashes, such as "a-b-c".
var tableNameWhiteList = regexp.MustCompile(`^[a-zA-Z0-9_]+(?:-[a-zA-Z0-9_]+)*(?:\.[a-zA-Z0-9_]+(?:-[a-zA-Z0-9_]+)*)?$`)
// ValidateTableName returns an error unless table matches tableNameWhiteList.
// A valid name has at least one character and contains only lower and upper
// case ASCII letters, digits, underscores and dashes. A dash may not appear
// at the beginning or the end of a name part, nor twice in a row. A single
// dot is also allowed to scope a table within a schema (e.g. public.my_table).
//
// QueryData relies on this check because it concatenates the table name into
// the query without quoting it.
func ValidateTableName(table string) error {
	if tableNameWhiteList.MatchString(table) {
		return nil
	}
	return errs.New(
		"forbidden table name, it can only contains letters, numbers, underscores and dashes not in a row. Got: %s",
		table,
	)
}