From e44b8a0451b8b56e6fc3687e2d38d5d476901a3a Mon Sep 17 00:00:00 2001 From: mxlxm Date: Fri, 1 Sep 2017 18:10:13 +0800 Subject: [PATCH 1/2] support client collation - since charset would issue an additional "SET NAMES xxx" query, using collation is better. see: https://github.com/go-sql-driver/mysql#charset - use the cli default-character-set arg value as the session charset. --- mysql/charset.go | 15 ++++++++------- server/conn.go | 9 ++++++--- server/conn_test.go | 31 +++++++++++++++++++++++++++++++ server/driver_tidb.go | 6 ++++++ session.go | 8 ++++++++ util/charset/charset.go | 14 ++++++++++++++ 6 files changed, 73 insertions(+), 10 deletions(-) diff --git a/mysql/charset.go b/mysql/charset.go index dd388ce0c311..07c6b9defbac 100644 --- a/mysql/charset.go +++ b/mysql/charset.go @@ -551,13 +551,14 @@ var CollationNames = map[string]uint8{ // MySQL collation information. const ( - UTF8Charset = "utf8" - UTF8MB4Charset = "utf8mb4" - DefaultCharset = UTF8Charset - DefaultCollationID = 83 - BinaryCollationID = 63 - UTF8DefaultCollation = "utf8_bin" - DefaultCollationName = UTF8DefaultCollation + UTF8Charset = "utf8" + UTF8MB4Charset = "utf8mb4" + DefaultCharset = UTF8Charset + DefaultCollationID = 83 + UTF8MB4GeneralCollationID = 45 + BinaryCollationID = 63 + UTF8DefaultCollation = "utf8_bin" + DefaultCollationName = UTF8DefaultCollation ) // IsUTF8Charset checks if charset is utf8 or utf8mb4 diff --git a/server/conn.go b/server/conn.go index 62ae654d4612..b690214a294b 100644 --- a/server/conn.go +++ b/server/conn.go @@ -147,8 +147,11 @@ func (cc *clientConn) writeInitialHandshake() error { data = append(data, 0) // capability flag lower 2 bytes, using default capability here data = append(data, byte(defaultCapability), byte(defaultCapability>>8)) - // charset, utf-8 default - data = append(data, uint8(mysql.DefaultCollationID)) + // charset + if cc.collation == 0 { + cc.collation = uint8(mysql.DefaultCollationID) + } + data = append(data, cc.collation) //status data = 
append(data, dumpUint16(mysql.ServerStatusAutocommit)...) // below 13 byte may not be used @@ -204,7 +207,7 @@ func handshakeResponseFromData(packet *handshakeResponse41, data []byte) (err er pos += 4 // skip max packet size pos += 4 - // charset, skip, if you want to use another charset, use set names + // charset packet.Collation = data[pos] pos++ // skip reserved 23[00] diff --git a/server/conn_test.go b/server/conn_test.go index 8a353c21e73a..1199c1e0f60c 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -132,6 +132,37 @@ func (ts ConnTestSuite) TestInitialHandshake(c *C) { c.Assert(outBuffer.Bytes()[4:], DeepEquals, expected.Bytes()) } +func (ts ConnTestSuite) TestInitialHandshakeWithCollation(c *C) { + c.Parallel() + var outBuffer bytes.Buffer + cc := &clientConn{ + connectionID: 1, + salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}, + collation: mysql.UTF8MB4GeneralCollationID, // utf8mb4 collate utf8mb4_general_ci + pkt: &packetIO{ + wb: bufio.NewWriter(&outBuffer), + }, + } + err := cc.writeInitialHandshake() + c.Assert(err, IsNil) + + expected := new(bytes.Buffer) + expected.WriteByte(0x0a) // Protocol + expected.WriteString(mysql.ServerVersion) // Version + expected.WriteByte(0x00) // NULL + binary.Write(expected, binary.LittleEndian, int32(1)) // Connection ID + expected.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00}) // Salt + binary.Write(expected, binary.LittleEndian, int16(defaultCapability&0xFFFF)) // Server Capability + expected.WriteByte(uint8(mysql.UTF8MB4GeneralCollationID)) // Server Language + binary.Write(expected, binary.LittleEndian, mysql.ServerStatusAutocommit) // Server Status + binary.Write(expected, binary.LittleEndian, int16((defaultCapability>>16)&0xFFFF)) // Extended Server Capability + expected.WriteByte(0x15) // Authentication Plugin Length + expected.Write([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) // Unused + 
expected.Write([]byte{0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x00}) // Salt + expected.WriteString("mysql_native_password") // Authentication Plugin + expected.WriteByte(0x00) // NULL + c.Assert(outBuffer.Bytes()[4:], DeepEquals, expected.Bytes()) +} func mapIdentical(m1, m2 map[string]string) bool { return mapBelong(m1, m2) && mapBelong(m2, m1) } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index fa4217278e01..407fbc3df4a9 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/auth" + "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -128,6 +129,11 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, if err != nil { return nil, errors.Trace(err) } + var cs, co string + if cs, co, err = charset.GetCharsetInfoByID(int(collation)); err != nil { + return nil, errors.Trace(err) + } + session.SetCharset(cs, co) session.SetClientCapability(capability) session.SetConnectionID(connID) tc := &TiDBContext{ diff --git a/session.go b/session.go index e13904c137bb..bdf3c3e31bb6 100644 --- a/session.go +++ b/session.go @@ -71,6 +71,7 @@ type Session interface { DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags. 
SetConnectionID(uint64) + SetCharset(cs, co string) SetSessionManager(util.SessionManager) Close() Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool @@ -181,6 +182,13 @@ func (s *session) SetConnectionID(connectionID uint64) { s.sessionVars.ConnectionID = connectionID } +func (s *session) SetCharset(cs, co string) { + for _, v := range variable.SetNamesVariables { + s.sessionVars.Systems[v] = cs + } + s.sessionVars.Systems[variable.CollationConnection] = co +} + func (s *session) SetSessionManager(sm util.SessionManager) { s.sessionManager = sm } diff --git a/util/charset/charset.go b/util/charset/charset.go index dce4746fee40..4f63be5e2abd 100644 --- a/util/charset/charset.go +++ b/util/charset/charset.go @@ -17,6 +17,7 @@ import ( "strings" "github.com/juju/errors" + "github.com/pingcap/tidb/mysql" ) // Charset is a charset. @@ -151,6 +152,19 @@ func GetCharsetDesc(cs string) (*Desc, error) { return desc, nil } +// GetCharsetInfoByID returns charset and collation for id as cs_number. +func GetCharsetInfoByID(id int) (string, string, error) { + if id == mysql.DefaultCollationID { + return mysql.DefaultCharset, mysql.DefaultCollationName, nil + } + for _, collation := range collations { + if id == collation.ID { + return collation.CharsetName, collation.Name, nil + } + } + return "", "", errors.Errorf("Unknown charset id %d", id) +} + // GetCollations returns a list for all collations. 
func GetCollations() []*Collation { return collations From 3bd38ae469f413bcdcf26b6e5142db112a62aba7 Mon Sep 17 00:00:00 2001 From: mxlxm Date: Mon, 4 Sep 2017 23:00:19 +0800 Subject: [PATCH 2/2] address comment --- mysql/charset.go | 15 +++++++-------- server/conn_test.go | 31 ------------------------------- server/driver_tidb.go | 6 ++---- server/server_test.go | 35 +++++++++++++++++++++++++++++++++++ server/tidb_test.go | 5 +++++ session.go | 10 ++++++++-- util/charset/charset.go | 8 ++++---- 7 files changed, 61 insertions(+), 49 deletions(-) diff --git a/mysql/charset.go b/mysql/charset.go index 07c6b9defbac..dd388ce0c311 100644 --- a/mysql/charset.go +++ b/mysql/charset.go @@ -551,14 +551,13 @@ var CollationNames = map[string]uint8{ // MySQL collation information. const ( - UTF8Charset = "utf8" - UTF8MB4Charset = "utf8mb4" - DefaultCharset = UTF8Charset - DefaultCollationID = 83 - UTF8MB4GeneralCollationID = 45 - BinaryCollationID = 63 - UTF8DefaultCollation = "utf8_bin" - DefaultCollationName = UTF8DefaultCollation + UTF8Charset = "utf8" + UTF8MB4Charset = "utf8mb4" + DefaultCharset = UTF8Charset + DefaultCollationID = 83 + BinaryCollationID = 63 + UTF8DefaultCollation = "utf8_bin" + DefaultCollationName = UTF8DefaultCollation ) // IsUTF8Charset checks if charset is utf8 or utf8mb4 diff --git a/server/conn_test.go b/server/conn_test.go index 1199c1e0f60c..8a353c21e73a 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -132,37 +132,6 @@ func (ts ConnTestSuite) TestInitialHandshake(c *C) { c.Assert(outBuffer.Bytes()[4:], DeepEquals, expected.Bytes()) } -func (ts ConnTestSuite) TestInitialHandshakeWithCollation(c *C) { - c.Parallel() - var outBuffer bytes.Buffer - cc := &clientConn{ - connectionID: 1, - salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}, - collation: mysql.UTF8MB4GeneralCollationID, // utf8mb4 collate utf8mb4_general_ci - pkt: &packetIO{ - wb: bufio.NewWriter(&outBuffer), - }, - 
} - err := cc.writeInitialHandshake() - c.Assert(err, IsNil) - - expected := new(bytes.Buffer) - expected.WriteByte(0x0a) // Protocol - expected.WriteString(mysql.ServerVersion) // Version - expected.WriteByte(0x00) // NULL - binary.Write(expected, binary.LittleEndian, int32(1)) // Connection ID - expected.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00}) // Salt - binary.Write(expected, binary.LittleEndian, int16(defaultCapability&0xFFFF)) // Server Capability - expected.WriteByte(uint8(mysql.UTF8MB4GeneralCollationID)) // Server Language - binary.Write(expected, binary.LittleEndian, mysql.ServerStatusAutocommit) // Server Status - binary.Write(expected, binary.LittleEndian, int16((defaultCapability>>16)&0xFFFF)) // Extended Server Capability - expected.WriteByte(0x15) // Authentication Plugin Length - expected.Write([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) // Unused - expected.Write([]byte{0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x00}) // Salt - expected.WriteString("mysql_native_password") // Authentication Plugin - expected.WriteByte(0x00) // NULL - c.Assert(outBuffer.Bytes()[4:], DeepEquals, expected.Bytes()) -} func mapIdentical(m1, m2 map[string]string) bool { return mapBelong(m1, m2) && mapBelong(m2, m1) } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 407fbc3df4a9..40a0e1f87bb5 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/auth" - "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -129,11 +128,10 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, if err != nil { return nil, errors.Trace(err) } - var cs, co string - if cs, co, err = charset.GetCharsetInfoByID(int(collation)); err != nil { + err = session.SetCollation(int(collation)) + if err != nil { return nil, errors.Trace(err) } - 
session.SetCharset(cs, co) session.SetClientCapability(capability) session.SetConnectionID(connID) tc := &TiDBContext{ diff --git a/server/server_test.go b/server/server_test.go index df31102f3b0d..407607d50504 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -246,6 +246,41 @@ func runTestSpecialType(t *C) { }) } +func runTestClientWithCollation(t *C) { + runTests(t, func(config *mysql.Config) { + config.Collation = "utf8mb4_general_ci" + }, func(dbt *DBTest) { + var name, charset, collation string + // check session variable collation_connection + rows := dbt.mustQuery("show variables like 'collation_connection'") + t.Assert(rows.Next(), IsTrue) + err := rows.Scan(&name, &collation) + t.Assert(err, IsNil) + t.Assert(collation, Equals, "utf8mb4_general_ci") + + // check session variable character_set_client + rows = dbt.mustQuery("show variables like 'character_set_client'") + t.Assert(rows.Next(), IsTrue) + err = rows.Scan(&name, &charset) + t.Assert(err, IsNil) + t.Assert(charset, Equals, "utf8mb4") + + // check session variable character_set_results + rows = dbt.mustQuery("show variables like 'character_set_results'") + t.Assert(rows.Next(), IsTrue) + err = rows.Scan(&name, &charset) + t.Assert(err, IsNil) + t.Assert(charset, Equals, "utf8mb4") + + // check session variable character_set_connection + rows = dbt.mustQuery("show variables like 'character_set_connection'") + t.Assert(rows.Next(), IsTrue) + err = rows.Scan(&name, &charset) + t.Assert(err, IsNil) + t.Assert(charset, Equals, "utf8mb4") + }) +} + func runTestPreparedString(t *C) { runTestsOnNewDB(t, nil, "PreparedString", func(dbt *DBTest) { dbt.mustExec("create table test (a char(10), b char(10))") diff --git a/server/tidb_test.go b/server/tidb_test.go index 80f3559c92c9..05573d3b7c0a 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -152,3 +152,8 @@ func (ts *TidbTestSuite) TestSocket(c *C) { config.Strict = true }, "SocketRegression") } + +func (ts *TidbTestSuite) 
TestClientWithCollation(c *C) { + c.Parallel() + runTestClientWithCollation(c) +} diff --git a/session.go b/session.go index bdf3c3e31bb6..4be893858bf1 100644 --- a/session.go +++ b/session.go @@ -49,6 +49,7 @@ import ( "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/auth" + "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" "github.com/pingcap/tipb/go-binlog" goctx "golang.org/x/net/context" @@ -71,7 +72,7 @@ type Session interface { DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags. SetConnectionID(uint64) - SetCharset(cs, co string) + SetCollation(coID int) error SetSessionManager(util.SessionManager) Close() Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool @@ -182,11 +183,16 @@ func (s *session) SetConnectionID(connectionID uint64) { s.sessionVars.ConnectionID = connectionID } -func (s *session) SetCharset(cs, co string) { +func (s *session) SetCollation(coID int) error { + cs, co, err := charset.GetCharsetInfoByID(coID) + if err != nil { + return errors.Trace(err) + } for _, v := range variable.SetNamesVariables { s.sessionVars.Systems[v] = cs } s.sessionVars.Systems[variable.CollationConnection] = co + return nil } func (s *session) SetSessionManager(sm util.SessionManager) { diff --git a/util/charset/charset.go b/util/charset/charset.go index 4f63be5e2abd..78dee7eb0d5c 100644 --- a/util/charset/charset.go +++ b/util/charset/charset.go @@ -153,16 +153,16 @@ func GetCharsetDesc(cs string) (*Desc, error) { } // GetCharsetInfoByID returns charset and collation for id as cs_number. 
-func GetCharsetInfoByID(id int) (string, string, error) { - if id == mysql.DefaultCollationID { +func GetCharsetInfoByID(coID int) (string, string, error) { + if coID == mysql.DefaultCollationID { return mysql.DefaultCharset, mysql.DefaultCollationName, nil } for _, collation := range collations { - if id == collation.ID { + if coID == collation.ID { return collation.CharsetName, collation.Name, nil } } - return "", "", errors.Errorf("Unknown charset id %d", id) + return "", "", errors.Errorf("Unknown charset id %d", coID) } // GetCollations returns a list for all collations.