-
Notifications
You must be signed in to change notification settings - Fork 568
/
db.go
145 lines (130 loc) · 3.72 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package pachsql
import (
"context"
"net"
"strconv"
"strings"
"sync"
"github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/jmoiron/sqlx"
"github.com/pachyderm/pachyderm/v2/src/internal/errors"
"go.uber.org/zap"
)
// Protocol names recognized in URL.Protocol by OpenURL. OpenURL also accepts
// "postgresql" as a synonym for ProtocolPostgres.
const (
	// ProtocolPostgres selects the "pgx" PostgreSQL driver in OpenURL.
	ProtocolPostgres = "postgres"
	// ProtocolMySQL selects the "mysql" driver in OpenURL.
	ProtocolMySQL = "mysql"
)

const (
	// listTablesQuery selects (schemaname, tablename) pairs from
	// pg_catalog.pg_tables, skipping the pg_catalog and information_schema
	// schemas, ordered by schema and then table name. It reads a PostgreSQL
	// system catalog, so it is only meaningful against PostgreSQL.
	// Used by ListTables.
	listTablesQuery = `
SELECT schemaname, tablename
FROM pg_catalog.pg_tables
WHERE schemaname != 'pg_catalog'
AND schemaname != 'information_schema'
ORDER BY schemaname, tablename;
`
)
// fixMysqlLoggerOnce guards the one-time installation, performed by OpenURL,
// of a zap-backed logger as the mysql driver's global logger.
var fixMysqlLoggerOnce sync.Once

// Project-wide aliases so callers depend on pachsql rather than on sqlx directly.
type (
	// DB is an alias for sqlx.DB, the standard database handle used throughout the project.
	DB = sqlx.DB
	// Tx is an alias for sqlx.Tx, the standard transaction type used throughout the project.
	Tx = sqlx.Tx
	// Stmt is an alias for sqlx.Stmt, the standard prepared statement type used throughout the project.
	Stmt = sqlx.Stmt
	// RowMap is an alias for map[string]interface{}, the row type used by sqlx.MapScan.
	RowMap = map[string]interface{}
)

// SchemaTable identifies a single table by its schema name and table name.
type SchemaTable struct {
	SchemaName string `json:"schemaname"`
	TableName  string `json:"tablename"`
}
// OpenURL opens a database connection pool for the database described by u.
//
// The driver is chosen from u.Protocol. "postgres" and "postgresql" use the
// "pgx" driver, and "mysql" uses the "mysql" driver. Any other protocol is
// rejected with an error. If password != "" then it will be used for
// authentication.
//
// No connection is attempted here, so the database may still be unreachable.
// Callers that need that guarantee should call Ping on the result.
func OpenURL(u URL, password string) (*DB, error) {
	var (
		driverName string
		dsn        string
		dsnErr     error
	)
	switch u.Protocol {
	case ProtocolPostgres, "postgresql":
		driverName = "pgx"
		dsn, dsnErr = postgresDSN(u, password)
	case ProtocolMySQL:
		driverName = "mysql"
		dsn, dsnErr = mySQLDSN(u, password)
		// mysql.SetLogger replaces a package-global logger, so route it into
		// zap only once, the first time a MySQL pool is opened.
		fixMysqlLoggerOnce.Do(func() {
			stdLogger := zap.NewStdLog(zap.L().Named("mysql"))
			stdLogger.Println("enabled global mysql logger")
			mysql.SetLogger(stdLogger) //nolint:errcheck
		})
	default:
		return nil, errors.Errorf("database protocol %q not supported", u.Protocol)
	}
	if dsnErr != nil {
		return nil, errors.Wrapf(dsnErr, "failed to generate DSN: %v", u)
	}
	pool, openErr := sqlx.Open(driverName, dsn)
	return pool, errors.EnsureStack(openErr)
}
// ListTables returns every table in db outside the pg_catalog and
// information_schema schemas, ordered by schema name and then table name.
//
// The query reads pg_catalog.pg_tables, so this is PostgreSQL-specific.
// Any query error is returned wrapped with "list tables".
func ListTables(ctx context.Context, db *DB) ([]SchemaTable, error) {
	var result []SchemaTable
	if selectErr := db.SelectContext(ctx, &result, listTablesQuery); selectErr != nil {
		return nil, errors.Wrap(selectErr, "list tables")
	}
	return result, nil
}
// postgresDSN renders u, plus an optional password, as a PostgreSQL
// keyword/value connection string such as
// "user='me' host='localhost' port='5432' dbname='db'".
//
// How the fields are chosen:
//   - The port defaults to 5432 when u.Port is zero.
//   - password is included only when it is non-empty.
//   - Entries in u.Params override the derived user/host/port/dbname/password
//     fields.
//
// Every value is wrapped in single quotes, with backslashes and single quotes
// backslash-escaped. Without the quotes, an empty value (e.g. an unset user)
// would cause the parser to read the following "key=value" pair as its value.
// Values containing spaces, quotes or backslashes, which are common in
// passwords, would also break parsing or inject extra settings.
//
// It returns an error if u.Schema is set.
func postgresDSN(u URL, password string) (string, error) {
	if u.Schema != "" {
		return "", errors.New("postgres DSN should not contain schema name")
	}
	port := u.Port
	if port == 0 {
		port = 5432
	}
	fields := map[string]string{
		"user":   u.User,
		"host":   u.Host,
		"port":   strconv.Itoa(int(port)),
		"dbname": u.Database,
	}
	if password != "" {
		fields["password"] = password
	}
	// Explicit params take precedence over the fields derived from u.
	for k, v := range u.Params {
		fields[k] = v
	}
	dsnParts := make([]string, 0, len(fields))
	for k, v := range fields {
		dsnParts = append(dsnParts, k+"="+quotePostgresDSNValue(v))
	}
	return strings.Join(dsnParts, " "), nil
}

// postgresDSNValueEscaper backslash-escapes the two characters that are
// special inside a single-quoted keyword/value connection-string value.
var postgresDSNValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quotePostgresDSNValue returns v as a single-quoted, escaped keyword/value
// connection-string value. An empty v becomes '' (an explicit empty value).
func quotePostgresDSNValue(v string) string {
	return "'" + postgresDSNValueEscaper.Replace(v) + "'"
}
// mySQLDSN renders u, plus a password, as a go-sql-driver/mysql DSN using a
// TCP connection to u.Host.
//
// Field handling:
//   - The port defaults to 3306 when u.Port is zero.
//   - u.Params is copied, so the caller's map is never modified.
//   - The parseTime parameter is always forced to "true", overriding any value
//     supplied in u.Params.
//   - AllowNativePasswords is enabled.
//
// It returns an error if u.Schema is set.
func mySQLDSN(u URL, password string) (string, error) {
	if u.Schema != "" {
		return "", errors.New("mysql DSN should not contain schema name")
	}
	p := u.Port
	if p == 0 {
		p = 3306
	}
	cfg := mysql.Config{
		Net:                  "tcp",
		Addr:                 net.JoinHostPort(u.Host, strconv.Itoa(int(p))),
		User:                 u.User,
		Passwd:               password,
		DBName:               u.Database,
		AllowNativePasswords: true,
	}
	cfg.Params = copyParams(u.Params)
	cfg.Params["parseTime"] = "true"
	return cfg.FormatDSN(), nil
}
// copyParams returns a shallow copy of x.
//
// The result is always a non-nil map, even when x is nil, so callers may
// write into it safely. Changes to the copy never affect x.
func copyParams(x map[string]string) map[string]string {
	dup := make(map[string]string, len(x))
	for key := range x {
		dup[key] = x[key]
	}
	return dup
}