@@ -23,7 +23,6 @@ import (
2323 "io"
2424 "net"
2525 "net/http"
26- "net/url"
2726 "os"
2827 "strconv"
2928 "strings"
@@ -58,28 +57,38 @@ type MySQLConnectParam struct {
5857 Vars map [string ]string
5958}
6059
61- func (param * MySQLConnectParam ) ToDSN () string {
62- hostPort := net .JoinHostPort (param .Host , strconv .Itoa (param .Port ))
63- dsn := fmt .Sprintf ("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s" ,
64- param .User , param .Password , hostPort ,
65- param .SQLMode , param .MaxAllowedPacket , param .TLS )
60+ func (param * MySQLConnectParam ) ToDriverConfig () * mysql.Config {
61+ cfg := mysql .NewConfig ()
62+ cfg .Params = make (map [string ]string )
63+
64+ cfg .User = param .User
65+ cfg .Passwd = param .Password
66+ cfg .Net = "tcp"
67+ cfg .Addr = net .JoinHostPort (param .Host , strconv .Itoa (param .Port ))
68+ cfg .Params ["charset" ] = "utf8mb4"
69+ cfg .Params ["sql_mode" ] = fmt .Sprintf ("'%s'" , param .SQLMode )
70+ cfg .MaxAllowedPacket = int (param .MaxAllowedPacket )
71+ cfg .TLSConfig = param .TLS
6672
6773 for k , v := range param .Vars {
68- dsn + = fmt .Sprintf ("&%s= '%s'" , k , url . QueryEscape ( v ) )
74+ cfg . Params [ k ] = fmt .Sprintf ("'%s'" , v )
6975 }
70-
71- return dsn
76+ return cfg
7277}
7378
74- func tryConnectMySQL (dsn string ) (* sql.DB , error ) {
75- driverName := "mysql"
76- failpoint .Inject ("MockMySQLDriver" , func (val failpoint.Value ) {
77- driverName = val .(string )
79+ func tryConnectMySQL (cfg * mysql.Config ) (* sql.DB , error ) {
80+ failpoint .Inject ("MustMySQLPassword" , func (val failpoint.Value ) {
81+ pwd := val .(string )
82+ if cfg .Passwd != pwd {
83+ failpoint .Return (nil , & mysql.MySQLError {Number : tmysql .ErrAccessDenied , Message : "access denied" })
84+ }
85+ failpoint .Return (nil , nil )
7886 })
79- db , err := sql . Open ( driverName , dsn )
87+ c , err := mysql . NewConnector ( cfg )
8088 if err != nil {
8189 return nil , errors .Trace (err )
8290 }
91+ db := sql .OpenDB (c )
8392 if err = db .Ping (); err != nil {
8493 _ = db .Close ()
8594 return nil , errors .Trace (err )
@@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) {
8998
9099// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
91100// we will try to connect MySQL with the base64 decoding of the password.
92- func ConnectMySQL (dsn string ) (* sql.DB , error ) {
93- cfg , err := mysql .ParseDSN (dsn )
94- if err != nil {
95- return nil , errors .Trace (err )
96- }
101+ func ConnectMySQL (cfg * mysql.Config ) (* sql.DB , error ) {
97102 // Try plain password first.
98- db , firstErr := tryConnectMySQL (dsn )
103+ db , firstErr := tryConnectMySQL (cfg )
99104 if firstErr == nil {
100105 return db , nil
101106 }
@@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
104109 // If password is encoded by base64, try the decoded string as well.
105110 if password , decodeErr := base64 .StdEncoding .DecodeString (cfg .Passwd ); decodeErr == nil && string (password ) != cfg .Passwd {
106111 cfg .Passwd = string (password )
107- db , err = tryConnectMySQL (cfg . FormatDSN () )
112+ db2 , err : = tryConnectMySQL (cfg )
108113 if err == nil {
109- return db , nil
114+ return db2 , nil
110115 }
111116 }
112117 }
@@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
115120}
116121
117122func (param * MySQLConnectParam ) Connect () (* sql.DB , error ) {
118- db , err := ConnectMySQL (param .ToDSN ())
123+ db , err := ConnectMySQL (param .ToDriverConfig ())
119124 if err != nil {
120125 return nil , errors .Trace (err )
121126 }
0 commit comments