Skip to content

Commit d037637

Browse files
authored
*: don't use DSN to avoid some security problems (#38342)
1 parent 22b85b9 commit d037637

File tree

12 files changed

+144
-130
lines changed

12 files changed

+144
-130
lines changed

br/pkg/lightning/checkpoints/checkpoints.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {
517517

518518
switch cfg.Checkpoint.Driver {
519519
case config.CheckpointDriverMySQL:
520-
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
520+
var (
521+
db *sql.DB
522+
err error
523+
)
524+
if cfg.Checkpoint.MySQLParam != nil {
525+
db, err = cfg.Checkpoint.MySQLParam.Connect()
526+
} else {
527+
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
528+
}
521529
if err != nil {
522530
return nil, errors.Trace(err)
523531
}
@@ -546,7 +554,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error
546554
}
547555
switch cfg.Checkpoint.Driver {
548556
case config.CheckpointDriverMySQL:
549-
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
557+
var (
558+
db *sql.DB
559+
err error
560+
)
561+
if cfg.Checkpoint.MySQLParam != nil {
562+
db, err = cfg.Checkpoint.MySQLParam.Connect()
563+
} else {
564+
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
565+
}
550566
if err != nil {
551567
return false, errors.Trace(err)
552568
}

br/pkg/lightning/common/util.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"io"
2424
"net"
2525
"net/http"
26-
"net/url"
2726
"os"
2827
"strconv"
2928
"strings"
@@ -58,28 +57,38 @@ type MySQLConnectParam struct {
5857
Vars map[string]string
5958
}
6059

61-
func (param *MySQLConnectParam) ToDSN() string {
62-
hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
63-
dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
64-
param.User, param.Password, hostPort,
65-
param.SQLMode, param.MaxAllowedPacket, param.TLS)
60+
func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config {
61+
cfg := mysql.NewConfig()
62+
cfg.Params = make(map[string]string)
63+
64+
cfg.User = param.User
65+
cfg.Passwd = param.Password
66+
cfg.Net = "tcp"
67+
cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
68+
cfg.Params["charset"] = "utf8mb4"
69+
cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode)
70+
cfg.MaxAllowedPacket = int(param.MaxAllowedPacket)
71+
cfg.TLSConfig = param.TLS
6672

6773
for k, v := range param.Vars {
68-
dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v))
74+
cfg.Params[k] = fmt.Sprintf("'%s'", v)
6975
}
70-
71-
return dsn
76+
return cfg
7277
}
7378

74-
func tryConnectMySQL(dsn string) (*sql.DB, error) {
75-
driverName := "mysql"
76-
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
77-
driverName = val.(string)
79+
func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
80+
failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) {
81+
pwd := val.(string)
82+
if cfg.Passwd != pwd {
83+
failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"})
84+
}
85+
failpoint.Return(nil, nil)
7886
})
79-
db, err := sql.Open(driverName, dsn)
87+
c, err := mysql.NewConnector(cfg)
8088
if err != nil {
8189
return nil, errors.Trace(err)
8290
}
91+
db := sql.OpenDB(c)
8392
if err = db.Ping(); err != nil {
8493
_ = db.Close()
8594
return nil, errors.Trace(err)
@@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) {
8998

9099
// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
91100
// we will try to connect MySQL with the base64 decoding of the password.
92-
func ConnectMySQL(dsn string) (*sql.DB, error) {
93-
cfg, err := mysql.ParseDSN(dsn)
94-
if err != nil {
95-
return nil, errors.Trace(err)
96-
}
101+
func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
97102
// Try plain password first.
98-
db, firstErr := tryConnectMySQL(dsn)
103+
db, firstErr := tryConnectMySQL(cfg)
99104
if firstErr == nil {
100105
return db, nil
101106
}
@@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
104109
// If password is encoded by base64, try the decoded string as well.
105110
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
106111
cfg.Passwd = string(password)
107-
db, err = tryConnectMySQL(cfg.FormatDSN())
112+
db2, err := tryConnectMySQL(cfg)
108113
if err == nil {
109-
return db, nil
114+
return db2, nil
110115
}
111116
}
112117
}
@@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
115120
}
116121

117122
func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
118-
db, err := ConnectMySQL(param.ToDSN())
123+
db, err := ConnectMySQL(param.ToDriverConfig())
119124
if err != nil {
120125
return nil, errors.Trace(err)
121126
}

br/pkg/lightning/common/util_test.go

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@ package common_test
1616

1717
import (
1818
"context"
19-
"database/sql"
20-
"database/sql/driver"
2119
"encoding/base64"
2220
"encoding/json"
2321
"fmt"
2422
"io"
25-
"math/rand"
2623
"net/http"
2724
"net/http/httptest"
28-
"strconv"
2925
"testing"
3026
"time"
3127

@@ -35,7 +31,6 @@ import (
3531
"github.com/pingcap/failpoint"
3632
"github.com/pingcap/tidb/br/pkg/lightning/common"
3733
"github.com/pingcap/tidb/br/pkg/lightning/log"
38-
tmysql "github.com/pingcap/tidb/errno"
3934
"github.com/stretchr/testify/assert"
4035
"github.com/stretchr/testify/require"
4136
)
@@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) {
8580
require.Regexp(t, ".*http status code != 200.*", err.Error())
8681
}
8782

88-
func TestToDSN(t *testing.T) {
89-
param := common.MySQLConnectParam{
90-
Host: "127.0.0.1",
91-
Port: 4000,
92-
User: "root",
93-
Password: "123456",
94-
SQLMode: "strict",
95-
MaxAllowedPacket: 1234,
96-
TLS: "cluster",
97-
Vars: map[string]string{
98-
"tidb_distsql_scan_concurrency": "1",
99-
},
100-
}
101-
require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
102-
103-
param.Host = "::1"
104-
require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
105-
}
106-
107-
type mockDriver struct {
108-
driver.Driver
109-
plainPsw string
110-
}
111-
112-
func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
113-
cfg, err := mysql.ParseDSN(dsn)
114-
if err != nil {
115-
return nil, err
116-
}
117-
accessDenied := cfg.Passwd != m.plainPsw
118-
return &mockConn{accessDenied: accessDenied}, nil
119-
}
120-
121-
type mockConn struct {
122-
driver.Conn
123-
driver.Pinger
124-
accessDenied bool
125-
}
126-
127-
func (c *mockConn) Ping(ctx context.Context) error {
128-
if c.accessDenied {
129-
return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
130-
}
131-
return nil
132-
}
133-
134-
func (c *mockConn) Close() error {
135-
return nil
136-
}
137-
13883
func TestConnect(t *testing.T) {
13984
plainPsw := "dQAUoDiyb1ucWZk7"
140-
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
141-
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})
14285

14386
require.NoError(t, failpoint.Enable(
144-
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
145-
fmt.Sprintf("return(\"%s\")", driverName)))
87+
"github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword",
88+
fmt.Sprintf("return(\"%s\")", plainPsw)))
14689
defer func() {
147-
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
90+
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword"))
14891
}()
14992

15093
param := common.MySQLConnectParam{
@@ -155,13 +98,11 @@ func TestConnect(t *testing.T) {
15598
SQLMode: "strict",
15699
MaxAllowedPacket: 1234,
157100
}
158-
db, err := param.Connect()
101+
_, err := param.Connect()
159102
require.NoError(t, err)
160-
require.NoError(t, db.Close())
161103
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
162-
db, err = param.Connect()
104+
_, err = param.Connect()
163105
require.NoError(t, err)
164-
require.NoError(t, db.Close())
165106
}
166107

167108
func TestIsContextCanceledError(t *testing.T) {

br/pkg/lightning/config/config.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -553,11 +553,12 @@ type TikvImporter struct {
553553
}
554554

555555
type Checkpoint struct {
556-
Schema string `toml:"schema" json:"schema"`
557-
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
558-
Driver string `toml:"driver" json:"driver"`
559-
Enable bool `toml:"enable" json:"enable"`
560-
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
556+
Schema string `toml:"schema" json:"schema"`
557+
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
558+
MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN.
559+
Driver string `toml:"driver" json:"driver"`
560+
Enable bool `toml:"enable" json:"enable"`
561+
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
561562
}
562563

563564
type Cron struct {
@@ -1142,7 +1143,7 @@ func (cfg *Config) AdjustCheckPoint() {
11421143
MaxAllowedPacket: defaultMaxAllowedPacket,
11431144
TLS: cfg.TiDB.TLS,
11441145
}
1145-
cfg.Checkpoint.DSN = param.ToDSN()
1146+
cfg.Checkpoint.MySQLParam = &param
11461147
case CheckpointDriverFile:
11471148
cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb"
11481149
}

br/pkg/lightning/config/config_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"github.com/BurntSushi/toml"
3333
"github.com/pingcap/tidb/br/pkg/lightning/common"
3434
"github.com/pingcap/tidb/br/pkg/lightning/config"
35-
"github.com/pingcap/tidb/parser/mysql"
3635
"github.com/stretchr/testify/require"
3736
)
3837

@@ -626,7 +625,9 @@ func TestLoadConfig(t *testing.T) {
626625
taskCfg.TiDB.DistSQLScanConcurrency = 1
627626
err = taskCfg.Adjust(context.Background())
628627
require.NoError(t, err)
629-
require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN)
628+
equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN()
629+
expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27"
630+
require.Equal(t, expectedDSN, equivalentDSN)
630631

631632
result := taskCfg.String()
632633
require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result)

cmd/importer/db.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"strconv"
2323
"strings"
2424

25-
_ "github.com/go-sql-driver/mysql"
25+
mysql2 "github.com/go-sql-driver/mysql"
2626
"github.com/pingcap/errors"
2727
"github.com/pingcap/log"
2828
"github.com/pingcap/tidb/parser/mysql"
@@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error {
318318
}
319319

320320
func createDB(cfg DBConfig) (*sql.DB, error) {
321-
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name)
322-
db, err := sql.Open("mysql", dbDSN)
321+
driverCfg := mysql2.NewConfig()
322+
driverCfg.User = cfg.User
323+
driverCfg.Passwd = cfg.Password
324+
driverCfg.Net = "tcp"
325+
driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port)
326+
driverCfg.DBName = cfg.Name
327+
328+
c, err := mysql2.NewConnector(driverCfg)
323329
if err != nil {
324330
return nil, errors.Trace(err)
325331
}
326-
327-
return db, nil
332+
return sql.OpenDB(c), nil
328333
}
329334

330335
func closeDB(db *sql.DB) error {

dumpling/export/config.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,31 @@ func (conf *Config) GetDSN(db string) string {
218218
return dsn
219219
}
220220

221+
// GetDriverConfig returns the MySQL driver config from Config.
222+
func (conf *Config) GetDriverConfig(db string) *mysql.Config {
223+
driverCfg := mysql.NewConfig()
224+
// maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection.
225+
// https://github.com/go-sql-driver/mysql#maxallowedpacket
226+
hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
227+
driverCfg.User = conf.User
228+
driverCfg.Passwd = conf.Password
229+
driverCfg.Net = "tcp"
230+
driverCfg.Addr = hostPort
231+
driverCfg.DBName = db
232+
driverCfg.Collation = "utf8mb4_general_ci"
233+
driverCfg.ReadTimeout = conf.ReadTimeout
234+
driverCfg.WriteTimeout = 30 * time.Second
235+
driverCfg.InterpolateParams = true
236+
driverCfg.MaxAllowedPacket = 0
237+
if conf.Security.DriveTLSName != "" {
238+
driverCfg.TLSConfig = conf.Security.DriveTLSName
239+
}
240+
if conf.AllowCleartextPasswords {
241+
driverCfg.AllowCleartextPasswords = true
242+
}
243+
return driverCfg
244+
}
245+
221246
func timestampDirName() string {
222247
return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339))
223248
}

0 commit comments

Comments
 (0)