Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support client specified collation #4409

Merged
merged 6 commits into from Sep 5, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions server/conn.go
Expand Up @@ -147,8 +147,11 @@ func (cc *clientConn) writeInitialHandshake() error {
data = append(data, 0)
// capability flag lower 2 bytes, using default capability here
data = append(data, byte(defaultCapability), byte(defaultCapability>>8))
// charset, utf-8 default
data = append(data, uint8(mysql.DefaultCollationID))
// charset
if cc.collation == 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cc.collation is initialized to mysql.DefaultCollationID in newConn already.
So this branch never run.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, but TestInitialHandshake in conn_test.go initialized a clientConn without collation field, and then called cc.writeInitialHandshake(), without this branch the case would fail.

cc.collation = uint8(mysql.DefaultCollationID)
}
data = append(data, cc.collation)
//status
data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...)
// below 13 byte may not be used
Expand Down Expand Up @@ -204,7 +207,7 @@ func handshakeResponseFromData(packet *handshakeResponse41, data []byte) (err er
pos += 4
// skip max packet size
pos += 4
// charset, skip, if you want to use another charset, use set names
// charset
packet.Collation = data[pos]
pos++
// skip reserved 23[00]
Expand Down
4 changes: 4 additions & 0 deletions server/driver_tidb.go
Expand Up @@ -128,6 +128,10 @@ func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8,
if err != nil {
return nil, errors.Trace(err)
}
err = session.SetCollation(int(collation))
if err != nil {
return nil, errors.Trace(err)
}
session.SetClientCapability(capability)
session.SetConnectionID(connID)
tc := &TiDBContext{
Expand Down
35 changes: 35 additions & 0 deletions server/server_test.go
Expand Up @@ -246,6 +246,41 @@ func runTestSpecialType(t *C) {
})
}

func runTestClientWithCollation(t *C) {
runTests(t, func(config *mysql.Config) {
config.Collation = "utf8mb4_general_ci"
}, func(dbt *DBTest) {
var name, charset, collation string
// check session variable collation_connection
rows := dbt.mustQuery("show variables like 'collation_connection'")
t.Assert(rows.Next(), IsTrue)
err := rows.Scan(&name, &collation)
t.Assert(err, IsNil)
t.Assert(collation, Equals, "utf8mb4_general_ci")

// check session variable character_set_client
rows = dbt.mustQuery("show variables like 'character_set_client'")
t.Assert(rows.Next(), IsTrue)
err = rows.Scan(&name, &charset)
t.Assert(err, IsNil)
t.Assert(charset, Equals, "utf8mb4")

// check session variable character_set_results
rows = dbt.mustQuery("show variables like 'character_set_results'")
t.Assert(rows.Next(), IsTrue)
err = rows.Scan(&name, &charset)
t.Assert(err, IsNil)
t.Assert(charset, Equals, "utf8mb4")

// check session variable character_set_connection
rows = dbt.mustQuery("show variables like 'character_set_connection'")
t.Assert(rows.Next(), IsTrue)
err = rows.Scan(&name, &charset)
t.Assert(err, IsNil)
t.Assert(charset, Equals, "utf8mb4")
})
}

func runTestPreparedString(t *C) {
runTestsOnNewDB(t, nil, "PreparedString", func(dbt *DBTest) {
dbt.mustExec("create table test (a char(10), b char(10))")
Expand Down
5 changes: 5 additions & 0 deletions server/tidb_test.go
Expand Up @@ -152,3 +152,8 @@ func (ts *TidbTestSuite) TestSocket(c *C) {
config.Strict = true
}, "SocketRegression")
}

func (ts *TidbTestSuite) TestClientWithCollation(c *C) {
c.Parallel()
runTestClientWithCollation(c)
}
14 changes: 14 additions & 0 deletions session.go
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/pingcap/tidb/store/tikv/oracle"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/auth"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tipb/go-binlog"
goctx "golang.org/x/net/context"
Expand All @@ -71,6 +72,7 @@ type Session interface {
DropPreparedStmt(stmtID uint32) error
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCollation(coID int) error
SetSessionManager(util.SessionManager)
Close()
Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool
Expand Down Expand Up @@ -181,6 +183,18 @@ func (s *session) SetConnectionID(connectionID uint64) {
s.sessionVars.ConnectionID = connectionID
}

func (s *session) SetCollation(coID int) error {
cs, co, err := charset.GetCharsetInfoByID(coID)
if err != nil {
return errors.Trace(err)
}
for _, v := range variable.SetNamesVariables {
s.sessionVars.Systems[v] = cs
}
s.sessionVars.Systems[variable.CollationConnection] = co
return nil
}

func (s *session) SetSessionManager(sm util.SessionManager) {
s.sessionManager = sm
}
Expand Down
14 changes: 14 additions & 0 deletions util/charset/charset.go
Expand Up @@ -17,6 +17,7 @@ import (
"strings"

"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
)

// Charset is a charset.
Expand Down Expand Up @@ -151,6 +152,19 @@ func GetCharsetDesc(cs string) (*Desc, error) {
return desc, nil
}

// GetCharsetInfoByID returns charset and collation for id as cs_number.
func GetCharsetInfoByID(coID int) (string, string, error) {
if coID == mysql.DefaultCollationID {
return mysql.DefaultCharset, mysql.DefaultCollationName, nil
}
for _, collation := range collations {
if coID == collation.ID {
return collation.CharsetName, collation.Name, nil
}
}
return "", "", errors.Errorf("Unknown charset id %d", coID)
}

// GetCollations returns a list for all collations.
func GetCollations() []*Collation {
return collations
Expand Down