Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
package server
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
goerr "errors"
"fmt"
"io"
"net"
"os/user"
"runtime/pprof"
"runtime/trace"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/privilege/conn"
"github.com/pingcap/tidb/privilege/privileges/ldap"
"github.com/pingcap/tidb/server/internal/column"
"github.com/pingcap/tidb/server/internal/dump"
util2 "github.com/pingcap/tidb/server/internal/util"
server_metrics "github.com/pingcap/tidb/server/metrics"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
storeerr "github.com/pingcap/tidb/store/driver/error"
"github.com/pingcap/tidb/tablecodec"
tidbutil "github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/dbterror/exeerrors"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
tlsutil "github.com/pingcap/tidb/util/tls"
topsqlstate "github.com/pingcap/tidb/util/topsql/state"
"github.com/pingcap/tidb/util/tracing"
"github.com/prometheus/client_golang/prometheus"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
// Connection status values stored in clientConn.status and read or written
// through getStatus, setStatus and CompareAndSwapStatus.
const (
	// connStatusDispatching is the first value (iota, i.e. 0); newClientConn
	// uses it as the initial status of a connection.
	connStatusDispatching int32 = iota
	// connStatusReading repeats the previous expression, so it is int32(1).
	connStatusReading
	connStatusShutdown = variable.ConnStatusShutdown // Closed by server.
	connStatusWaitShutdown = 3 // Notified by server to close.
)
// newClientConn builds a *clientConn bound to server s. The connection gets a
// fresh connection ID from the domain, starts in the dispatching state with
// the default collation and the native-password auth plugin, and has PROXY
// protocol handling enabled when the server config lists proxy networks.
func newClientConn(s *Server) *clientConn {
	cc := new(clientConn)
	cc.server = s
	cc.connectionID = s.dom.NextConnID()
	cc.collation = mysql.DefaultCollationID
	cc.alloc = arena.NewAllocator(32 * 1024)
	cc.chunkAlloc = chunk.NewAllocator()
	cc.status = connStatusDispatching
	cc.lastActive = time.Now()
	cc.authPlugin = mysql.AuthNativePassword
	cc.quit = make(chan struct{})
	cc.ppEnabled = s.cfg.ProxyProtocol.Networks != ""
	return cc
}
// clientConn represents a connection between server and client. It maintains
// connection-specific state and handles client queries.
type clientConn struct {
	pkt *packetIO // a helper to read and write data in packet format.
	bufReadConn *bufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn.
	tlsConn *tls.Conn // TLS connection, nil if not TLS.
	server *Server // a reference of server instance.
	capability uint32 // client capability affects the way server handles client request.
	connectionID uint64 // atomically allocated by a global variable, unique in process scope.
	user string // user of the client.
	dbname string // default database name.
	salt []byte // random bytes used for authentication.
	alloc arena.Allocator // a memory allocator for reducing memory allocation.
	chunkAlloc chunk.Allocator
	lastPacket []byte // latest sql query string, currently used for logging error.
	// ShowProcess() and mysql.ComChangeUser both visit this field: ShowProcess() reads information through
	// the TiDBContext and mysql.ComChangeUser re-creates it, so a lock is required here.
	// Use getCtx/setCtx to access the pointer under the lock.
	ctx struct {
		sync.RWMutex
		*TiDBContext // an interface to execute sql statements.
	}
	attrs map[string]string // attributes parsed from client handshake response.
	serverHost string // server host
	peerHost string // peer host
	peerPort string // peer port
	status int32 // dispatching/reading/shutdown/waitshutdown
	lastCode uint16 // last error code
	collation uint8 // collation used by client, may be different from the collation used by database.
	lastActive time.Time // last active time
	authPlugin string // default authentication plugin
	isUnixSocket bool // connection is Unix Socket file
	rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets.
	inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
	socketCredUID uint32 // UID from the other end of the Unix Socket
	// mu is used for cancelling the execution of current transaction.
	mu struct {
		sync.RWMutex
		cancelFunc context.CancelFunc
	}
	// quit is closed once clientConn quits Run().
	quit chan struct{}
	extensions *extension.SessionExtensions
	// ppEnabled reports whether PROXY protocol is enabled; newClientConn sets
	// it when the server's ProxyProtocol.Networks setting is non-empty.
	ppEnabled bool
}
// getCtx returns the current *TiDBContext (possibly nil), reading it under
// the read lock of cc.ctx.
func (cc *clientConn) getCtx() *TiDBContext {
	cc.ctx.RLock()
	current := cc.ctx.TiDBContext
	cc.ctx.RUnlock()
	return current
}
// setCtx replaces the connection's *TiDBContext under the write lock of
// cc.ctx. Passing nil clears it.
func (cc *clientConn) setCtx(ctx *TiDBContext) {
	cc.ctx.Lock()
	defer cc.ctx.Unlock()
	cc.ctx.TiDBContext = ctx
}
// String describes the connection for logging: connection ID, remote address,
// session status bits (printed in binary), collation name and user.
func (cc *clientConn) String() string {
	collationName := mysql.Collations[cc.collation]
	remoteAddr := cc.bufReadConn.RemoteAddr()
	sessionStatus := cc.ctx.Status()
	const layout = "id:%d, addr:%s status:%b, collation:%s, user:%s"
	return fmt.Sprintf(layout, cc.connectionID, remoteAddr, sessionStatus, collationName, cc.user)
}
// setStatus atomically stores status on the connection and, when a session
// context exists, mirrors it into the session variables' ConnectionStatus.
func (cc *clientConn) setStatus(status int32) {
	atomic.StoreInt32(&cc.status, status)
	sess := cc.getCtx()
	if sess == nil {
		return
	}
	atomic.StoreInt32(&sess.GetSessionVars().ConnectionStatus, status)
}
// getStatus atomically loads the connection status.
func (cc *clientConn) getStatus() int32 {
	current := atomic.LoadInt32(&cc.status)
	return current
}
// CompareAndSwapStatus atomically replaces the connection status with
// newStatus if it currently equals oldStatus, and reports whether the swap
// happened. Unlike setStatus it does not touch the session variables.
func (cc *clientConn) CompareAndSwapStatus(oldStatus, newStatus int32) bool {
	swapped := atomic.CompareAndSwapInt32(&cc.status, oldStatus, newStatus)
	return swapped
}
// authSwitchRequest is used by the server to ask the client to switch to a different authentication
// plugin. MySQL 8.0 libmysqlclient based clients by default always try `caching_sha2_password`, even
// when the server advertises its default to be `mysql_native_password`. In addition to this, switching
// may be needed on a per-user basis as the authentication method is set per user.
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
// https://bugs.mysql.com/bug.php?id=93044
//
// The request carries the client-side plugin name followed by plugin data:
// the SASL auth method name for LDAP SASL, the salt otherwise. On success the
// client's reply packet is returned and cc.authPlugin is set to plugin.
// Write, flush and read failures are logged and returned.
func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) {
	// The plugin name sent to the client may differ from the server-side name.
	clientPlugin := plugin
	switch plugin {
	case mysql.AuthLDAPSASL:
		clientPlugin += "_client"
	case mysql.AuthLDAPSimple:
		clientPlugin = mysql.AuthMySQLClearPassword
	}
	failpoint.Inject("FakeAuthSwitch", func() {
		failpoint.Return([]byte(clientPlugin), nil)
	})

	pluginData := cc.salt
	if plugin == mysql.AuthLDAPSASL {
		// LDAP SASL sends the SASL auth method name instead of the salt.
		pluginData = []byte(ldap.LDAPSASLAuthImpl.GetSASLAuthMethod())
	}

	// Payload: 1-byte header, NUL-terminated plugin name, NUL-terminated plugin data.
	enclen := 1 + len(clientPlugin) + 1 + len(pluginData) + 1
	// Reserve the 4 leading bytes (the slice's initial length) in the capacity
	// too, so the appends below stay within the allocated buffer.
	data := cc.alloc.AllocWithLen(4, 4+enclen)
	data = append(data, mysql.AuthSwitchRequest) // switch request
	data = append(data, []byte(clientPlugin)...)
	data = append(data, byte(0x00)) // requires null
	data = append(data, pluginData...)
	data = append(data, byte(0x00))

	err := cc.writePacket(data)
	if err != nil {
		logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err))
		return nil, err
	}
	err = cc.flush(ctx)
	if err != nil {
		logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err))
		return nil, err
	}
	resp, err := cc.readPacket()
	if err != nil {
		err = errors.SuspendStack(err)
		if errors.Cause(err) == io.EOF {
			logutil.Logger(ctx).Warn("authSwitchRequest response fail due to connection has be closed by client-side")
		} else {
			logutil.Logger(ctx).Warn("authSwitchRequest response fail", zap.Error(err))
		}
		return nil, err
	}
	cc.authPlugin = plugin
	return resp, nil
}
// handshake runs the connection phase. It sends the initial handshake packet,
// reads the client's (optional SSLRequest and) handshake response, which
// negotiates features and authenticates the user, runs init_connect, and
// finally replies with an OK packet. After a successful handshake the client
// can send queries.
//
// Compression is only switched on after the OK packet has been flushed.
func (cc *clientConn) handshake(ctx context.Context) error {
	if writeErr := cc.writeInitialHandshake(ctx); writeErr != nil {
		if errors.Cause(writeErr) == io.EOF {
			logutil.Logger(ctx).Debug("Could not send handshake due to connection has be closed by client-side")
		} else {
			logutil.Logger(ctx).Debug("Write init handshake to client fail", zap.Error(errors.SuspendStack(writeErr)))
		}
		return writeErr
	}

	if respErr := cc.readOptionalSSLRequestAndHandshakeResponse(ctx); respErr != nil {
		// Report the failure to the client; a failure to report is only logged.
		if sendErr := cc.writeError(ctx, respErr); sendErr != nil {
			logutil.Logger(ctx).Debug("writeError failed", zap.Error(sendErr))
		}
		return respErr
	}

	// MySQL supports an "init_connect" query, which can be run on initial connection.
	// The query must return a non-error or the client is disconnected.
	if icErr := cc.initConnect(ctx); icErr != nil {
		logutil.Logger(ctx).Warn("init_connect failed", zap.Error(icErr))
		abortErr := errNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed")
		if sendErr := cc.writeError(ctx, abortErr); sendErr != nil {
			terror.Log(sendErr)
		}
		return abortErr
	}

	// Build the OK packet: header plus two zero bytes, then (for protocol 4.1
	// clients) the autocommit status and two more zero bytes.
	okPacket := cc.alloc.AllocWithLen(4, 32)
	okPacket = append(okPacket, mysql.OKHeader, 0, 0)
	if cc.capability&mysql.ClientProtocol41 > 0 {
		okPacket = dump.Uint16(okPacket, mysql.ServerStatusAutocommit)
		okPacket = append(okPacket, 0, 0)
	}

	writeErr := cc.writePacket(okPacket)
	// The sequence number is reset whether or not the write succeeded.
	cc.pkt.sequence = 0
	if writeErr != nil {
		writeErr = errors.SuspendStack(writeErr)
		logutil.Logger(ctx).Debug("write response to client failed", zap.Error(writeErr))
		return writeErr
	}
	if flushErr := cc.flush(ctx); flushErr != nil {
		flushErr = errors.SuspendStack(flushErr)
		logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(flushErr))
		return flushErr
	}

	// With mysql --compression-algorithms=zlib,zstd both flags are set, the result is Zlib
	switch {
	case cc.capability&mysql.ClientCompress > 0:
		cc.pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
	case cc.capability&mysql.ClientZstdCompressionAlgorithm > 0:
		cc.pkt.SetCompressionAlgorithm(mysql.CompressionZstd)
	}
	return nil
}
// Close removes the connection from the server's client map while holding
// server.rwlock, then releases the connection's resources via closeConn.
func (cc *clientConn) Close() error {
	srv := cc.server
	srv.rwlock.Lock()
	delete(srv.clients, cc.connectionID)
	remaining := len(srv.clients)
	srv.rwlock.Unlock()
	// closeConn runs after the lock is released.
	return closeConn(cc, remaining)
}
// closeConn publishes the given connection count to the connection gauge,
// releases the connection ID, closes the network connection if there is one,
// and closes the session context if there is one. Only the session close
// error is returned; a network close error is logged at debug level.
func closeConn(cc *clientConn, connections int) error {
	metrics.ConnGauge.Set(float64(connections))
	cc.server.dom.ReleaseConnID(cc.connectionID)

	if conn := cc.bufReadConn; conn != nil {
		// We need to expect connection might have already disconnected.
		// This is because closeConn() might be called after a connection read-timeout.
		if closeErr := conn.Close(); closeErr != nil {
			logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(closeErr))
		}
	}

	// Close statements and session
	// This will release advisory locks, row locks, etc.
	sess := cc.getCtx()
	if sess == nil {
		return nil
	}
	return sess.Close()
}
// closeWithoutLock does the same work as Close but does not acquire
// server.rwlock itself.
// NOTE(review): presumably the caller already holds server.rwlock — confirm at call sites.
func (cc *clientConn) closeWithoutLock() error {
	clients := cc.server.clients
	delete(clients, cc.connectionID)
	return closeConn(cc, len(clients))
}
// writeInitialHandshake sends server version, connection ID, server capability, collation, server status
// and auth salt to the client.
//
// The packet is assembled field by field in wire order, so the order of the
// appends below must not change. A session is opened temporarily (if none
// exists) to read the global default auth plugin, then closed and cleared
// again before the packet is written and flushed.
func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
	// The first 4 bytes are left zeroed ahead of the payload.
	data := make([]byte, 4, 128)
	// min version 10
	data = append(data, 10)
	// server version[00]
	data = append(data, mysql.ServerVersion...)
	data = append(data, 0)
	// connection id (only the low 32 bits are sent, least significant byte first)
	data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24))
	// auth-plugin-data-part-1 (first 8 bytes of the salt)
	data = append(data, cc.salt[0:8]...)
	// filler [00]
	data = append(data, 0)
	// capability flag lower 2 bytes, using default capability here
	data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8))
	// charset: fall back to the default collation when none is set
	if cc.collation == 0 {
		cc.collation = uint8(mysql.DefaultCollationID)
	}
	data = append(data, cc.collation)
	// status
	data = dump.Uint16(data, mysql.ServerStatusAutocommit)
	// below 13 byte may not be used
	// capability flag upper 2 bytes, using default capability here
	data = append(data, byte(cc.server.capability>>16), byte(cc.server.capability>>24))
	// length of auth-plugin-data (salt length plus one)
	data = append(data, byte(len(cc.salt)+1))
	// reserved 10 [00]
	data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
	// auth-plugin-data-part-2 (rest of the salt, then a zero byte)
	data = append(data, cc.salt[8:]...)
	data = append(data, 0)
	// auth-plugin name: read from the global DefaultAuthPlugin variable, which
	// needs a session, so open one if none exists yet.
	if ctx := cc.getCtx(); ctx == nil {
		if err := cc.openSession(); err != nil {
			return err
		}
	}
	defAuthPlugin, err := cc.ctx.GetSessionVars().GetGlobalSystemVar(context.Background(), variable.DefaultAuthPlugin)
	if err != nil {
		return err
	}
	cc.authPlugin = defAuthPlugin
	data = append(data, []byte(defAuthPlugin)...)
	// Close the session to force this to be re-opened after we parse the response. This is needed
	// to ensure we use the collation and client flags from the response for the session.
	if err = cc.ctx.Close(); err != nil {
		return err
	}
	cc.setCtx(nil)
	// terminating zero byte of the auth-plugin name
	data = append(data, 0)
	if err = cc.writePacket(data); err != nil {
		return err
	}
	return cc.flush(ctx)
}
// readPacket reads one packet from the client. When a session context exists,
// the packet reader's size limit is first refreshed from the session's
// MaxAllowedPacket value.
func (cc *clientConn) readPacket() ([]byte, error) {
	// Use the snapshot returned by getCtx (taken under the lock) for both the
	// nil check and the read, instead of re-reading cc.ctx unlocked.
	if ctx := cc.getCtx(); ctx != nil {
		cc.pkt.setMaxAllowedPacket(ctx.GetSessionVars().MaxAllowedPacket)
	}
	return cc.pkt.readPacket()
}
// writePacket writes data to the client through cc.pkt and returns its error.
// When the "FakeClientConn" failpoint is active and cc.pkt is nil, it returns
// nil without writing anything.
func (cc *clientConn) writePacket(data []byte) error {
	failpoint.Inject("FakeClientConn", func() {
		if cc.pkt == nil {
			failpoint.Return(nil)
		}
	})
	return cc.pkt.writePacket(data)
}
// getSessionVarsWaitTimeout returns the session's wait_timeout value. It falls
// back to variable.DefWaitTimeout when the variable is not set or cannot be
// parsed as an unsigned integer; a parse failure is logged as a warning.
func (cc *clientConn) getSessionVarsWaitTimeout(ctx context.Context) uint64 {
	raw, found := cc.ctx.GetSessionVars().GetSystemVar(variable.WaitTimeout)
	if found {
		parsed, parseErr := strconv.ParseUint(raw, 10, 64)
		if parseErr == nil {
			return parsed
		}
		logutil.Logger(ctx).Warn("get sysval wait_timeout failed, use default value", zap.Error(parseErr))
	}
	return variable.DefWaitTimeout
}
// handshakeResponse41 holds the fields parsed from the client's handshake
// response (and from the common header it shares with an SSLRequest).
// It is filled in by parseHandshakeResponseHeader and parseHandshakeResponseBody.
type handshakeResponse41 struct {
	Capability uint32 // client capability flags, read little-endian from the first 4 bytes.
	Collation uint8 // collation byte from the header.
	User string // NUL-terminated user name.
	DBName string // database name; only parsed when ClientConnectWithDB is set.
	Auth []byte // authentication data sent by the client.
	AuthPlugin string // plugin name; only parsed when ClientPluginAuth is set.
	Attrs map[string]string // connection attributes; only parsed when ClientConnectAtts is set.
	ZstdLevel zstd.EncoderLevel // only parsed when ClientZstdCompressionAlgorithm is set.
}
// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
// It stores the capability flags and the collation byte into packet and
// returns the number of header bytes consumed. Data shorter than the header
// is logged and rejected with mysql.ErrMalformPacket.
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
func parseHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
	const (
		capabilityLen    = 4  // capability flags
		maxPacketSizeLen = 4  // max packet size, skipped
		collationLen     = 1  // charset; to use another charset, use set names
		reservedLen      = 23 // reserved 23[00], skipped
		headerLen        = capabilityLen + maxPacketSizeLen + collationLen + reservedLen
	)
	// Ensure there are enough data to read.
	if len(data) < headerLen {
		logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
		return 0, mysql.ErrMalformPacket
	}
	packet.Capability = binary.LittleEndian.Uint32(data[:capabilityLen])
	packet.Collation = data[capabilityLen+maxPacketSizeLen]
	return headerLen, nil
}
// parseHandshakeResponseBody parses the HandshakeResponse (except the common header part),
// starting at offset in data, and fills packet with the user name, auth data
// and, depending on packet.Capability, the database name, auth plugin name,
// connection attributes and zstd level.
//
// The parser does not bounds-check every access. Any panic raised while
// parsing (for example slicing past the end of a truncated packet) is
// recovered, logged, and turned into mysql.ErrMalformPacket.
func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) {
	defer func() {
		// Checking for a malformed packet via an out-of-range panic is ugly, but don't panic!
		if r := recover(); r != nil {
			logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data))
			err = mysql.ErrMalformPacket
		}
	}()
	// user name (NUL-terminated); a missing terminator makes the slice
	// expression panic, which is handled by the recover above.
	packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
	offset += len(packet.User) + 1
	if packet.Capability&mysql.ClientPluginAuthLenencClientData > 0 {
		// MySQL client sets the wrong capability, it will set this bit even server doesn't
		// support ClientPluginAuthLenencClientData.
		// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
		if data[offset] == 0x1 { // No auth data
			offset += 2
		} else {
			// length-encoded integer followed by that many bytes of auth data
			num, null, off := util2.ParseLengthEncodedInt(data[offset:])
			offset += off
			if !null {
				packet.Auth = data[offset : offset+int(num)]
				offset += int(num)
			}
		}
	} else if packet.Capability&mysql.ClientSecureConnection > 0 {
		// auth length (one byte) and auth
		authLen := int(data[offset])
		offset++
		packet.Auth = data[offset : offset+authLen]
		offset += authLen
	} else {
		// NUL-terminated auth data
		packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
		offset += len(packet.Auth) + 1
	}
	if packet.Capability&mysql.ClientConnectWithDB > 0 {
		// database name (NUL-terminated), only if any bytes remain
		if len(data[offset:]) > 0 {
			idx := bytes.IndexByte(data[offset:], 0)
			packet.DBName = string(data[offset : offset+idx])
			offset += idx + 1
		}
	}
	if packet.Capability&mysql.ClientPluginAuth > 0 {
		// auth plugin name (NUL-terminated). When no terminator is found idx is
		// -1, so s >= f, the name is left empty and offset does not move.
		idx := bytes.IndexByte(data[offset:], 0)
		s := offset
		f := offset + idx
		if s < f { // handle unexpected bad packets
			packet.AuthPlugin = string(data[s:f])
		}
		offset += idx + 1
	}
	if packet.Capability&mysql.ClientConnectAtts > 0 {
		if len(data[offset:]) == 0 {
			// Defend against ill-formatted packets; connection attributes are not important and can be ignored.
			return nil
		}
		if num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]); !null {
			offset += intOff // Length of variable length encoded integer itself in bytes
			row := data[offset : offset+int(num)]
			// This err is local to the block; attribute parse failures are
			// logged and ignored, and parsing stops here (the zstd level below
			// is not read in that case).
			attrs, err := parseAttrs(row)
			if err != nil {
				logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err))
				return nil
			}
			packet.Attrs = attrs
			offset += int(num) // Length of attributes
		}
	}
	if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 {
		// one byte holding the zstd compression level
		packet.ZstdLevel = zstd.EncoderLevelFromZstd(int(data[offset]))
	}
	return nil
}
// parseAttrs decodes data as a sequence of length-encoded key/value pairs
// into a map. On a decode error it returns the pairs collected so far
// together with the error.
func parseAttrs(data []byte) (map[string]string, error) {
	attrs := make(map[string]string)
	for cursor := 0; cursor < len(data); {
		key, _, keyLen, err := util2.ParseLengthEncodedBytes(data[cursor:])
		if err != nil {
			return attrs, err
		}
		cursor += keyLen

		val, _, valLen, err := util2.ParseLengthEncodedBytes(data[cursor:])
		if err != nil {
			return attrs, err
		}
		cursor += valLen

		attrs[string(key)] = string(val)
	}
	return attrs, nil
}
// readOptionalSSLRequestAndHandshakeResponse reads the client's reply to the
// initial handshake. The first packet is either an SSLRequest or the
// HandshakeResponse itself. For an SSLRequest (and when the server has a TLS
// config) the connection is upgraded to TLS and the real HandshakeResponse is
// read next. The parsed response is copied onto the connection, the auth
// plugin is reconciled via handleAuthPlugin, and finally a session is opened
// and the user authenticated.
func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error {
	// Read a packet. It may be a SSLRequest or HandshakeResponse.
	data, err := cc.readPacket()
	if err != nil {
		err = errors.SuspendStack(err)
		if errors.Cause(err) == io.EOF {
			logutil.Logger(ctx).Debug("wait handshake response fail due to connection has be closed by client-side")
		} else {
			logutil.Logger(ctx).Debug("wait handshake response fail", zap.Error(err))
		}
		return err
	}
	var resp handshakeResponse41
	var pos int
	if len(data) < 2 {
		logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
		return mysql.ErrMalformPacket
	}
	// Only the low 16 capability bits are inspected here, to reject clients
	// that do not set ClientProtocol41 before parsing the full header.
	capability := uint32(binary.LittleEndian.Uint16(data[:2]))
	if capability&mysql.ClientProtocol41 <= 0 {
		logutil.Logger(ctx).Error("ClientProtocol41 flag is not set, please upgrade client")
		return errNotSupportedAuthMode
	}
	pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
	if err != nil {
		terror.Log(err)
		return err
	}
	// After read packets we should update the client's host and port to grab
	// real client's IP and port from PROXY Protocol header if PROXY Protocol is enabled.
	_, _, err = cc.PeerHost("", true)
	if err != nil {
		terror.Log(err)
		return err
	}
	// If enable proxy protocol check audit plugins after update real IP
	if cc.ppEnabled {
		err = cc.server.checkAuditPlugin(cc)
		if err != nil {
			return err
		}
	}
	if resp.Capability&mysql.ClientSSL > 0 {
		tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig))
		// When the server has no TLS config, the packet is parsed below as a
		// regular HandshakeResponse.
		if tlsConfig != nil {
			// The packet is a SSLRequest, let's switch to TLS.
			if err = cc.upgradeToTLS(tlsConfig); err != nil {
				return err
			}
			// Read the following HandshakeResponse packet.
			data, err = cc.readPacket()
			if err != nil {
				logutil.Logger(ctx).Warn("read handshake response failure after upgrade to TLS", zap.Error(err))
				return err
			}
			pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
			if err != nil {
				terror.Log(err)
				return err
			}
		}
	} else if tlsutil.RequireSecureTransport.Load() && !cc.isUnixSocket {
		// If it's not a socket connection, we should reject the connection
		// because TLS is required.
		err := errSecureTransportRequired.FastGenByArgs()
		terror.Log(err)
		return err
	}
	// Read the remaining part of the packet.
	err = parseHandshakeResponseBody(ctx, &resp, data, pos)
	if err != nil {
		terror.Log(err)
		return err
	}
	// Keep only the capabilities that both client and server support.
	cc.capability = resp.Capability & cc.server.capability
	cc.user = resp.User
	cc.dbname = resp.DBName
	cc.collation = resp.Collation
	cc.attrs = resp.Attrs
	cc.pkt.zstdLevel = resp.ZstdLevel
	// handleAuthPlugin may rewrite resp.AuthPlugin and resp.Auth.
	err = cc.handleAuthPlugin(ctx, &resp)
	if err != nil {
		return err
	}
	// Plugins with an extra round trip replace resp.Auth with the client's
	// follow-up reply; the other known plugins need nothing here.
	switch resp.AuthPlugin {
	case mysql.AuthCachingSha2Password:
		resp.Auth, err = cc.authSha(ctx, resp)
		if err != nil {
			return err
		}
	case mysql.AuthTiDBSM3Password:
		resp.Auth, err = cc.authSM3(ctx, resp)
		if err != nil {
			return err
		}
	case mysql.AuthNativePassword:
	case mysql.AuthSocket:
	case mysql.AuthTiDBSessionToken:
	case mysql.AuthTiDBAuthToken:
	case mysql.AuthMySQLClearPassword:
	case mysql.AuthLDAPSASL:
	case mysql.AuthLDAPSimple:
	default:
		return errors.New("Unknown auth plugin")
	}
	err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
	if err != nil {
		logutil.Logger(ctx).Warn("open new session or authentication failure", zap.Error(err))
	}
	return err
}
// handleAuthPlugin reconciles the auth plugin announced in the handshake
// response with the one required for the user.
//
// For clients that support auth plugins (ClientPluginAuth set),
// checkAuthPlugin may perform an auth switch; any non-empty auth data it
// returns replaces resp.Auth. An unrecognized resulting plugin name is only
// logged here. For clients without plugin support the check is still run, and
// resp.AuthPlugin is then forced to mysql_native_password.
func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error {
	if resp.Capability&mysql.ClientPluginAuth > 0 {
		newAuth, err := cc.checkAuthPlugin(ctx, resp)
		if err != nil {
			logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err))
			return err
		}
		if len(newAuth) > 0 {
			resp.Auth = newAuth
		}
		switch resp.AuthPlugin {
		case mysql.AuthCachingSha2Password,
			mysql.AuthTiDBSM3Password,
			mysql.AuthNativePassword,
			mysql.AuthSocket,
			mysql.AuthTiDBSessionToken,
			// Accepted by readOptionalSSLRequestAndHandshakeResponse as well,
			// so it must not be reported as unknown.
			mysql.AuthTiDBAuthToken,
			mysql.AuthMySQLClearPassword,
			mysql.AuthLDAPSASL,
			mysql.AuthLDAPSimple:
			// Known plugin, nothing to do.
		default:
			logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin))
		}
	} else {
		// MySQL 5.1 and older clients don't support authentication plugins.
		logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client")
		_, err := cc.checkAuthPlugin(ctx, resp)
		if err != nil {
			return err
		}
		resp.AuthPlugin = mysql.AuthNativePassword
	}
	return nil
}
// authSha implements the caching_sha2_password specific part of the protocol.
// It answers with "fast auth fail" so that the client sends its full
// response, and returns that response with leading and trailing NUL bytes
// trimmed. An empty resp.Auth short-circuits to an empty result.
func (cc *clientConn) authSha(ctx context.Context, resp handshakeResponse41) ([]byte, error) {
	const (
		shaCommand       = 1
		requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel.
		fastAuthOk       = 3
		fastAuthFail     = 4
	)

	// If no password is specified, we don't send the FastAuthFail to do the full authentication
	// as that doesn't make sense without a password and confuses the client.
	// https://github.com/pingcap/tidb/issues/40831
	if len(resp.Auth) == 0 {
		return []byte{}, nil
	}

	// Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet.
	// This triggers the client to send the full response.
	if writeErr := cc.writePacket([]byte{0, 0, 0, 0, shaCommand, fastAuthFail}); writeErr != nil {
		logutil.Logger(ctx).Error("authSha packet write failed", zap.Error(writeErr))
		return nil, writeErr
	}
	if flushErr := cc.flush(ctx); flushErr != nil {
		logutil.Logger(ctx).Error("authSha packet flush failed", zap.Error(flushErr))
		return nil, flushErr
	}

	reply, readErr := cc.readPacket()
	if readErr != nil {
		logutil.Logger(ctx).Error("authSha packet read failed", zap.Error(readErr))
		return nil, readErr
	}
	return bytes.Trim(reply, "\x00"), nil
}
// authSM3 implements the tidb_sm3_password specific part of the protocol.
// tidb_sm3_password is very similar to caching_sha2_password: the server
// answers with "fast auth fail" and returns the client's full response with
// leading and trailing NUL bytes trimmed. An empty resp.Auth short-circuits
// to an empty result.
func (cc *clientConn) authSM3(ctx context.Context, resp handshakeResponse41) ([]byte, error) {
	// If no password is specified, we don't send the FastAuthFail to do the full authentication
	// as that doesn't make sense without a password and confuses the client.
	// https://github.com/pingcap/tidb/issues/40831
	if len(resp.Auth) == 0 {
		return []byte{}, nil
	}

	fastAuthFailPacket := []byte{0, 0, 0, 0, 1, 4} // fastAuthFail
	if writeErr := cc.writePacket(fastAuthFailPacket); writeErr != nil {
		logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(writeErr))
		return nil, writeErr
	}
	if flushErr := cc.flush(ctx); flushErr != nil {
		logutil.Logger(ctx).Error("authSM3 packet flush failed", zap.Error(flushErr))
		return nil, flushErr
	}

	reply, readErr := cc.readPacket()
	if readErr != nil {
		logutil.Logger(ctx).Error("authSM3 packet read failed", zap.Error(readErr))
		return nil, readErr
	}
	return bytes.Trim(reply, "\x00"), nil
}
// SessionStatusToString reports, as 0/1 flags, whether the session status has
// the in-transaction and autocommit bits set.
func (cc *clientConn) SessionStatusToString() string {
	var inTxn, autoCommit int
	sessionStatus := cc.ctx.Status()
	if sessionStatus&mysql.ServerStatusInTrans > 0 {
		inTxn = 1
	}
	if sessionStatus&mysql.ServerStatusAutocommit > 0 {
		autoCommit = 1
	}
	return fmt.Sprintf("inTxn:%d, autocommit:%d", inTxn, autoCommit)
}
// openSession opens a new session context through the server's driver, using
// the connection's ID, capability, collation, database name and (when on TLS)
// the TLS connection state, and installs it with setCtx. It then returns the
// result of the server's connection count check; the context stays installed
// even if that check fails.
func (cc *clientConn) openSession() error {
	var tlsState *tls.ConnectionState
	if cc.tlsConn != nil {
		state := cc.tlsConn.ConnectionState()
		tlsState = &state
	}

	sess, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsState, cc.extensions)
	if err != nil {
		return err
	}
	cc.setCtx(sess)
	return cc.server.checkConnectionCount()
}
// openSessionAndDoAuth opens a session if none exists, authenticates the user
// with authData under authPlugin, records the peer port on the session,
// switches to the requested default database (if any) and registers the
// server as the session manager. auth_socket is rejected for connections
// that are not Unix sockets.
func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) error {
	// Open a context unless this was done before.
	if cc.getCtx() == nil {
		if err := cc.openSession(); err != nil {
			return err
		}
	}

	hasPassword := "NO"
	if len(authData) != 0 {
		hasPassword = "YES"
	}
	host, port, err := cc.PeerHost(hasPassword, false)
	if err != nil {
		return err
	}

	if authPlugin == mysql.AuthSocket && !cc.isUnixSocket {
		return errAccessDeniedNoPassword.FastGenByArgs(cc.user, host)
	}

	identity := &auth.UserIdentity{Username: cc.user, Hostname: host, AuthPlugin: authPlugin}
	if err = cc.ctx.Auth(identity, authData, cc.salt, cc); err != nil {
		return err
	}
	cc.ctx.SetPort(port)

	if cc.dbname != "" {
		if _, err = cc.useDB(context.Background(), cc.dbname); err != nil {
			return err
		}
	}
	cc.ctx.SetSessionManager(cc.server)
	return nil
}
// checkAuthPlugin checks whether the authentication plugin of the server, the
// client and the user configuration match, and asks the client to switch
// plugins when they do not.
//
// It returns replacement auth data when the client had to answer an auth
// switch request (or, for auth_socket, the OS user name looked up from the
// socket peer's UID), and nil when the data already in resp can be used.
// resp.AuthPlugin is updated to the plugin that ends up being used.
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) {
	// Open a context unless this was done before.
	if cc.getCtx() == nil {
		if err := cc.openSession(); err != nil {
			return nil, err
		}
	}

	authData := resp.Auth
	// tidb_session_token is always permitted and skips stored user plugin.
	if resp.AuthPlugin == mysql.AuthTiDBSessionToken {
		return authData, nil
	}

	hasPassword := "YES"
	if len(authData) == 0 {
		hasPassword = "NO"
	}
	host, _, err := cc.PeerHost(hasPassword, false)
	if err != nil {
		return nil, err
	}

	// Find the identity of the user based on username and peer host.
	// The lookup error is replaced by a generic access-denied error.
	identity, err := cc.ctx.MatchIdentity(cc.user, host)
	if err != nil {
		return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
	}

	// Get the plugin for the identity. A failure is logged (with its cause)
	// and processing continues with whatever plugin name was returned.
	userplugin, err := cc.ctx.AuthPluginForUser(identity)
	if err != nil {
		logutil.Logger(ctx).Warn("Failed to get authentication method for user",
			zap.String("user", cc.user), zap.String("host", host), zap.Error(err))
	}
	failpoint.Inject("FakeUser", func(val failpoint.Value) {
		//nolint:forcetypeassert
		userplugin = val.(string)
	})

	if userplugin == mysql.AuthSocket {
		if !cc.isUnixSocket {
			return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
		}
		resp.AuthPlugin = mysql.AuthSocket
		// The auth data is the OS user name for the UID on the other end of
		// the Unix socket.
		osUser, err := user.LookupId(strconv.FormatUint(uint64(cc.socketCredUID), 10))
		if err != nil {
			return nil, err
		}
		return []byte(osUser.Username), nil
	}

	if len(userplugin) == 0 {
		// No user plugin set, assuming MySQL Native Password
		// This happens if the account doesn't exist or if the account doesn't have
		// a password set.
		if resp.AuthPlugin != mysql.AuthNativePassword && resp.Capability&mysql.ClientPluginAuth > 0 {
			resp.AuthPlugin = mysql.AuthNativePassword
			switched, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword)
			if err != nil {
				return nil, err
			}
			return switched, nil
		}
		return nil, nil
	}

	// If the authentication method sent by the server (cc.authPlugin) doesn't match
	// the plugin configured for the user account in the mysql.user.plugin column,
	// or if the authentication method sent by the server doesn't match the authentication
	// method sent by the client (resp.AuthPlugin), then we need to switch the authentication
	// method to match the one configured for that specific user.
	if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) {
		// tidb_auth_token users are switched to the clear-password plugin.
		if userplugin == mysql.AuthTiDBAuthToken {
			userplugin = mysql.AuthMySQLClearPassword
		}
		if resp.Capability&mysql.ClientPluginAuth > 0 {
			switched, err := cc.authSwitchRequest(ctx, userplugin)
			if err != nil {
				return nil, err
			}
			resp.AuthPlugin = userplugin
			return switched, nil
		} else if userplugin != mysql.AuthNativePassword {
			// MySQL 5.1 and older don't support authentication plugins yet
			return nil, errNotSupportedAuthMode
		}
	}
	return nil, nil
}
// PeerHost returns the client's host and port and caches them, together with
// the server-side host, on the connection.
//
// A previously cached value is returned as is, unless PROXY protocol is
// enabled and update is true, in which case the addresses are read again from
// the underlying connection. Unix socket connections report
// variable.DefHostname with an empty port. hasPassword is only used to build
// the access-denied error returned when an address cannot be split.
func (cc *clientConn) PeerHost(hasPassword string, update bool) (host, port string, err error) {
	// Only a PROXY-protocol connection asked to update bypasses the cache.
	refresh := cc.ppEnabled && update
	if len(cc.peerHost) > 0 && !refresh {
		return cc.peerHost, cc.peerPort, nil
	}

	if cc.isUnixSocket {
		cc.peerHost = variable.DefHostname
		cc.serverHost = variable.DefHostname
		return variable.DefHostname, "", nil
	}

	remoteAddr := cc.bufReadConn.RemoteAddr().String()
	host, port, err = net.SplitHostPort(remoteAddr)
	if err != nil {
		return host, port, errAccessDenied.GenWithStackByArgs(cc.user, remoteAddr, hasPassword)
	}
	cc.peerHost, cc.peerPort = host, port

	localAddr := cc.bufReadConn.LocalAddr().String()
	localHost, _, splitErr := net.SplitHostPort(localAddr)
	if splitErr != nil {
		// The peer host/port are already cached at this point; only the
		// server host is left unset.
		return host, port, errAccessDenied.GenWithStackByArgs(cc.user, remoteAddr, hasPassword)
	}
	cc.serverHost = localHost
	return host, port, nil
}
// skipInitConnect follows MySQL's rules of when init-connect should be skipped.
// In 5.7 it is any user with SUPER privilege, but in 8.0 it is:
// - SUPER or the CONNECTION_ADMIN dynamic privilege.
// - (additional exception) users with expired passwords (not yet supported)
// In TiDB CONNECTION_ADMIN is satisfied by SUPER, so we only need to check once.
// It returns false when no privilege manager is available.
func (cc *clientConn) skipInitConnect() bool {
	privMgr := privilege.GetPrivilegeManager(cc.ctx.Session)
	roles := cc.ctx.GetSessionVars().ActiveRoles
	if privMgr == nil {
		return false
	}
	return privMgr.RequestDynamicVerification(roles, "CONNECTION_ADMIN", false)
}
// initResultEncoder sets up the result encoder of this connection from the
// character_set_results system variable. If the variable cannot be read, a
// warning is logged and the encoder is created with an empty charset name.
func (cc *clientConn) initResultEncoder(ctx context.Context) {
	sessVars := cc.ctx.GetSessionVars()
	charset, err := sessVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetResults)
	if err != nil {
		logutil.Logger(ctx).Warn("get character_set_results system variable failed", zap.Error(err))
		charset = ""
	}
	cc.rsEncoder = column.NewResultEncoder(charset)
}
// initInputEncoder sets up the input decoder of this connection from the
// character_set_client system variable. If the variable cannot be read, a
// warning is logged and the decoder is created with an empty charset name.
func (cc *clientConn) initInputEncoder(ctx context.Context) {
	sessVars := cc.ctx.GetSessionVars()
	charset, err := sessVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetClient)
	if err != nil {
		logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err))
		charset = ""
	}
	cc.inputDecoder = util2.NewInputDecoder(charset)
}
// initConnect runs the init_connect SQL statement(s) if the global variable
// has been specified. The semantics are MySQL compatible: nothing is executed
// when the variable is empty or when skipInitConnect reports that the user is
// exempt.
//
// Result sets produced by the statements are drained and closed. The first
// error from reading the variable, parsing, executing, draining or closing is
// returned.
func (cc *clientConn) initConnect(ctx context.Context) error {
	val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect)
	if err != nil {
		return err
	}
	if val == "" || cc.skipInitConnect() {
		return nil
	}
	logutil.Logger(ctx).Debug("init_connect starting")
	stmts, err := cc.ctx.Parse(ctx, val)
	if err != nil {
		return err
	}
	for _, stmt := range stmts {
		rs, err := cc.ctx.ExecuteStmt(ctx, stmt)
		if err != nil {
			return err
		}
		if rs == nil {
			continue
		}
		// init_connect does not care about the results,
		// but they need to be drained because of lazy loading.
		req := rs.NewChunk(nil)
		for {
			if err = rs.Next(ctx, req); err != nil {
				// Release the result set before bailing out so it does not leak;
				// the read error is the one worth reporting, so a close error is
				// only logged.
				terror.Log(rs.Close())
				return err
			}
			if req.NumRows() == 0 {
				break
			}
		}
		if err := rs.Close(); err != nil {
			return err
		}
	}
	logutil.Logger(ctx).Debug("init_connect complete")
	return nil
}
// Run reads client query and writes query result to client in for loop, if there is a panic during query handling,
// it will be recovered and log the panic error.
// This function returns and the connection is closed if there is an IO error or there is a panic.
//
// On exit the deferred function closes the connection (unless its status is
// already connStatusShutdown) and closes the cc.quit channel.
func (cc *clientConn) Run(ctx context.Context) {
	defer func() {
		r := recover()
		if r != nil {
			// Log the panic together with the last statement, then try to
			// report it to the client as an error packet.
			logutil.Logger(ctx).Error("connection running loop panic",
				zap.Stringer("lastSQL", getLastStmtInConn{cc}),
				zap.String("err", fmt.Sprintf("%v", r)),
				zap.Stack("stack"),
			)
			err := cc.writeError(ctx, fmt.Errorf("%v", r))
			terror.Log(err)
			metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc()
		}
		if cc.getStatus() != connStatusShutdown {
			err := cc.Close()
			terror.Log(err)
		}
		close(cc.quit)
	}()
	// Usually, client connection status changes between [dispatching] <=> [reading].
	// When some event happens, server may notify this client connection by setting
	// the status to special values, for example: kill or graceful shutdown.
	// The client connection would detect the events when it fails to change status
	// by CAS operation, it would then take some actions accordingly.
	for {
		// Close connection between txn when we are going to shutdown server.
		// Note the current implementation when shutting down, for an idle connection, the connection may block at readPacket()
		// consider provider a way to close the connection directly after sometime if we can not read any data.
		if cc.server.inShutdownMode.Load() {
			if !cc.ctx.GetSessionVars().InTxn() {
				return
			}
		}
		if !cc.CompareAndSwapStatus(connStatusDispatching, connStatusReading) ||
			// The judge below will not be hit by all means,
			// But keep it stayed as a reminder and for the code reference for connStatusWaitShutdown.
			cc.getStatus() == connStatusWaitShutdown {
			return
		}
		cc.alloc.Reset()
		// close connection when idle time is more than wait_timeout
		// default 28800(8h), FIXME: should not block at here when we kill the connection.
		waitTimeout := cc.getSessionVarsWaitTimeout(ctx)
		cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
		start := time.Now()
		data, err := cc.readPacket()
		if err != nil {
			// io.EOF is not logged; every other read error is classified below.
			// All read errors end the loop.
			if terror.ErrorNotEqual(err, io.EOF) {
				if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() {
					if cc.getStatus() == connStatusWaitShutdown {
						logutil.Logger(ctx).Info("read packet timeout because of killed connection")
					} else {
						idleTime := time.Since(start)
						logutil.Logger(ctx).Info("read packet timeout, close this connection",
							zap.Duration("idle", idleTime),
							zap.Uint64("waitTimeout", waitTimeout),
							zap.Error(err),
						)
					}
				} else if errors.ErrorEqual(err, errNetPacketTooLarge) {
					// Tell the client why the connection is being dropped.
					err := cc.writeError(ctx, err)
					if err != nil {
						terror.Log(err)
					}
				} else {
					// Reads failing with "use of closed network connection" are
					// not logged.
					errStack := errors.ErrorStack(err)
					if !strings.Contains(errStack, "use of closed network connection") {
						logutil.Logger(ctx).Warn("read packet failed, close this connection",
							zap.Error(errors.SuspendStack(err)))
					}
				}
			}
			server_metrics.DisconnectByClientWithError.Inc()
			return
		}
		// Should check InTxn() to avoid execute `begin` stmt.
		if cc.server.inShutdownMode.Load() {
			if !cc.ctx.GetSessionVars().InTxn() {
				return
			}
		}
		// A failed CAS means the status was changed while reading; stop serving.
		if !cc.CompareAndSwapStatus(connStatusReading, connStatusDispatching) {
			return
		}
		startTime := time.Now()
		err = cc.dispatch(ctx, data)
		cc.ctx.GetSessionVars().ClearAlloc(&cc.chunkAlloc, err != nil)
		cc.chunkAlloc.Reset()
		if err != nil {
			cc.audit(plugin.Error) // tell the plugin API there was a dispatch error
			if terror.ErrorEqual(err, io.EOF) {
				// dispatch returns io.EOF for ComQuit: a normal disconnect.
				cc.addMetrics(data[0], startTime, nil)
				server_metrics.DisconnectNormal.Inc()
				return
			} else if terror.ErrResultUndetermined.Equal(err) {
				logutil.Logger(ctx).Error("result undetermined, close this connection", zap.Error(err))
				server_metrics.DisconnectErrorUndetermined.Inc()
				return
			} else if terror.ErrCritical.Equal(err) {
				metrics.CriticalErrorCounter.Add(1)
				logutil.Logger(ctx).Fatal("critical error, stop the server", zap.Error(err))
			}
			var txnMode string
			if ctx := cc.getCtx(); ctx != nil {
				txnMode = ctx.GetSessionVars().GetReadableTxnMode()
			}
			// Count the failure once per database name returned for the session.
			for _, dbName := range session.GetDBNames(cc.getCtx().GetSessionVars()) {
				metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err), dbName).Inc()
			}
			if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) {
				logutil.Logger(ctx).Debug("Expected error for FOR UPDATE NOWAIT", zap.Error(err))
			} else {
				var startTS uint64
				if ctx := cc.getCtx(); ctx != nil && ctx.GetSessionVars() != nil && ctx.GetSessionVars().TxnCtx != nil {
					startTS = ctx.GetSessionVars().TxnCtx.StartTS
				}
				logutil.Logger(ctx).Info("command dispatched failed",
					zap.String("connInfo", cc.String()),
					zap.String("command", mysql.Command2Str[data[0]]),
					zap.String("status", cc.SessionStatusToString()),
					zap.Stringer("sql", getLastStmtInConn{cc}),
					zap.String("txn_mode", txnMode),
					zap.Uint64("timestamp", startTS),
					zap.String("err", errStrForLog(err, cc.ctx.GetSessionVars().EnableRedactLog)),
				)
			}
			// Report the failure to the client; the connection stays open.
			err1 := cc.writeError(ctx, err)
			terror.Log(err1)
		}
		cc.addMetrics(data[0], startTime, err)
		// Reset the packet sequence counters before reading the next command.
		cc.pkt.sequence = 0
		cc.pkt.compressedSequence = 0
	}
}
// errStrForLog renders err for the "command dispatched failed" log entry.
//
// With log redaction enabled, a parse error is replaced by a fixed message
// because it may contain sensitive information such as a password or access
// key. Duplicate-key, parse and table-not-exists errors are logged without a
// stack; every other error is logged with its full error stack.
func errStrForLog(err error, enableRedactLog bool) string {
	switch {
	case enableRedactLog && parser.ErrParse.Equal(err):
		// currently, only ErrParse is considered when redaction is enabled.
		return "fail to parse SQL and can't redact when enable log redaction"
	case kv.ErrKeyExists.Equal(err), parser.ErrParse.Equal(err), infoschema.ErrTableNotExists.Equal(err):
		// These are common user-facing errors; skip the stack.
		return err.Error()
	default:
		return errors.ErrorStack(err)
	}
}
// addMetrics records metrics for one handled command: the per-command OK/Error
// counter, the transaction write-throughput SLI, affected-row counters for
// Insert/Replace/Delete/Update statements, and the query duration histogram
// for each database name of the session.
//
// Nothing is recorded for a ComQuery whose session carries the LastExecuteDDL
// value, because DDL execution time is recorded by the ddl package.
func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
	if cmd == mysql.ComQuery && cc.ctx.Value(sessionctx.LastExecuteDDL) != nil {
		// Don't take DDL execute time into account.
		return
	}

	failed := err != nil
	var counter prometheus.Counter
	switch {
	case failed && int(cmd) < len(server_metrics.QueryTotalCountErr):
		counter = server_metrics.QueryTotalCountErr[cmd]
	case !failed && int(cmd) < len(server_metrics.QueryTotalCountOk):
		counter = server_metrics.QueryTotalCountOk[cmd]
	}
	if counter == nil {
		// No pre-built counter for this command: look it up by label.
		result := "OK"
		if failed {
			result = "Error"
		}
		metrics.QueryTotalCounter.WithLabelValues(strconv.Itoa(int(cmd)), result).Inc()
	} else {
		counter.Inc()
	}

	cost := time.Since(startTime)
	sessVars := cc.ctx.GetSessionVars()
	affectedRows := cc.ctx.AffectedRows()
	cc.ctx.GetTxnWriteThroughputSLI().FinishExecuteStmt(cost, affectedRows, sessVars.InTxn())

	sqlType := sessVars.StmtCtx.StmtType
	if sqlType == "" {
		sqlType = metrics.LblGeneral
	}
	rows := float64(affectedRows)
	switch sqlType {
	case "Insert":
		server_metrics.AffectedRowsCounterInsert.Add(rows)
	case "Replace":
		server_metrics.AffectedRowsCounterReplace.Add(rows)
	case "Delete":
		server_metrics.AffectedRowsCounterDelete.Add(rows)
	case "Update":
		server_metrics.AffectedRowsCounterUpdate.Add(rows)
	}

	vars := cc.getCtx().GetSessionVars()
	seconds := cost.Seconds()
	for _, dbName := range session.GetDBNames(vars) {
		metrics.QueryDurationHistogram.WithLabelValues(sqlType, dbName, vars.ResourceGroupName).Observe(seconds)
	}
}
// dispatch handles client request based on command which is the first byte of the data.
// It also gets a token from server which is used to limit the concurrently handling clients.
// The most frequently used command is ComQuery.
//
// data must contain at least the command byte. The returned error is io.EOF
// for ComQuit; other errors come from the individual command handlers, from a
// malformed ComRefresh packet, or from an unsupported command.
func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
	defer func() {
		// reset killed for each request
		atomic.StoreUint32(&cc.ctx.GetSessionVars().Killed, 0)
	}()
	t := time.Now()
	// Record how long the connection was idle since the last command finished.
	if (cc.ctx.Status() & mysql.ServerStatusInTrans) > 0 {
		server_metrics.ConnIdleDurationHistogramInTxn.Observe(t.Sub(cc.lastActive).Seconds())
	} else {
		server_metrics.ConnIdleDurationHistogramNotInTxn.Observe(t.Sub(cc.lastActive).Seconds())
	}

	cfg := config.GetGlobalConfig()
	if cfg.OpenTracing.Enable {
		var r tracing.Region
		r, ctx = tracing.StartRegionEx(ctx, "server.dispatch")
		defer r.End()
	}

	// Publish the cancel function on the connection (under cc.mu) for use
	// outside this function.
	var cancelFunc context.CancelFunc
	ctx, cancelFunc = context.WithCancel(ctx)
	cc.mu.Lock()
	cc.mu.cancelFunc = cancelFunc
	cc.mu.Unlock()

	cc.lastPacket = data
	cmd := data[0]
	data = data[1:]
	if topsqlstate.TopSQLEnabled() {
		defer pprof.SetGoroutineLabels(ctx)
	}
	if variable.EnablePProfSQLCPU.Load() {
		label := getLastStmtInConn{cc}.PProfLabel()
		if len(label) > 0 {
			// The deferred call captures the context as it is before the label
			// is attached.
			defer pprof.SetGoroutineLabels(ctx)
			ctx = pprof.WithLabels(ctx, pprof.Labels("sql", label))
			pprof.SetGoroutineLabels(ctx)
		}
	}
	if trace.IsEnabled() {
		lc := getLastStmtInConn{cc}
		sqlType := lc.PProfLabel()
		if len(sqlType) > 0 {
			var task *trace.Task
			ctx, task = trace.NewTask(ctx, sqlType)
			defer task.End()

			trace.Log(ctx, "sql", lc.String())
			ctx = logutil.WithTraceLogger(ctx, cc.connectionID)

			// Reads the first 8 bytes of the task as its ID; this relies on the
			// layout of runtime/trace.Task.
			taskID := *(*uint64)(unsafe.Pointer(task))
			ctx = pprof.WithLabels(ctx, pprof.Labels("trace", strconv.FormatUint(taskID, 10)))
			pprof.SetGoroutineLabels(ctx)
		}
	}
	token := cc.server.getToken()
	defer func() {
		// if handleChangeUser failed, cc.ctx may be nil
		if ctx := cc.getCtx(); ctx != nil {
			ctx.SetProcessInfo("", t, mysql.ComSleep, 0)
		}

		cc.server.releaseToken(token)
		cc.lastActive = time.Now()
	}()

	vars := cc.ctx.GetSessionVars()
	// reset killed for each request
	atomic.StoreUint32(&vars.Killed, 0)
	if cmd < mysql.ComEnd {
		cc.ctx.SetCommandValue(cmd)
	}

	dataStr := string(hack.String(data))
	switch cmd {
	case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset,
		mysql.ComSetOption, mysql.ComChangeUser:
		cc.ctx.SetProcessInfo("", t, cmd, 0)
	case mysql.ComInitDB:
		cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0)
	}

	switch cmd {
	case mysql.ComQuit:
		return io.EOF
	case mysql.ComInitDB:
		node, err := cc.useDB(ctx, dataStr)
		cc.onExtensionStmtEnd(node, false, err)
		if err != nil {
			return err
		}
		return cc.writeOK(ctx)
	case mysql.ComQuery: // Most frequently used command.
		// For issue 1989
		// Input payload may end with byte '\0', we didn't find related mysql document about it, but mysql
		// implementation accept that case. So trim the last '\0' here as if the payload an EOF string.
		// See http://dev.mysql.com/doc/internals/en/com-query.html
		if len(data) > 0 && data[len(data)-1] == 0 {
			data = data[:len(data)-1]
			dataStr = string(hack.String(data))
		}
		return cc.handleQuery(ctx, dataStr)
	case mysql.ComFieldList:
		return cc.handleFieldList(ctx, dataStr)
	// ComCreateDB, ComDropDB
	case mysql.ComRefresh:
		// COM_REFRESH carries one sub-command byte after the command byte.
		// Reject a packet without it instead of panicking on data[0].
		if len(data) < 1 {
			return mysql.ErrMalformPacket
		}
		return cc.handleRefresh(ctx, data[0])
	case mysql.ComShutdown: // redirect to SQL
		if err := cc.handleQuery(ctx, "SHUTDOWN"); err != nil {
			return err
		}
		return cc.writeOK(ctx)
	case mysql.ComStatistics:
		return cc.writeStats(ctx)
	// ComProcessInfo, ComConnect, ComProcessKill, ComDebug
	case mysql.ComPing:
		return cc.writeOK(ctx)
	case mysql.ComChangeUser:
		return cc.handleChangeUser(ctx, data)
	// ComBinlogDump, ComTableDump, ComConnectOut, ComRegisterSlave
	case mysql.ComStmtPrepare:
		// For issue 39132, same as ComQuery
		if len(data) > 0 && data[len(data)-1] == 0 {
			data = data[:len(data)-1]
			dataStr = string(hack.String(data))
		}
		return cc.handleStmtPrepare(ctx, dataStr)
	case mysql.ComStmtExecute:
		return cc.handleStmtExecute(ctx, data)
	case mysql.ComStmtSendLongData:
		return cc.handleStmtSendLongData(data)
	case mysql.ComStmtClose:
		return cc.handleStmtClose(data)
	case mysql.ComStmtReset:
		return cc.handleStmtReset(ctx, data)
	case mysql.ComSetOption:
		return cc.handleSetOption(ctx, data)
	case mysql.ComStmtFetch:
		return cc.handleStmtFetch(ctx, data)
	// ComDaemon, ComBinlogDumpGtid
	case mysql.ComResetConnection:
		return cc.handleResetConnection(ctx)
	// ComEnd
	default:
		return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", nil, cmd)
	}
}
// writeStats answers ComStatistics with a human-readable status string and
// flushes it to the client. Only the uptime is a real value; it stays 0 when
// the server info cannot be fetched, in which case the failure is logged. The
// remaining fields are fixed placeholders.
func (cc *clientConn) writeStats(ctx context.Context) error {
	var (
		err       error
		uptimeSec int64
	)
	info := serverInfo{}
	if info.ServerInfo, err = infosync.GetServerInfo(); err == nil {
		startedAt := time.Unix(info.ServerInfo.StartTimestamp, 0)
		uptimeSec = int64(time.Since(startedAt).Seconds())
	} else {
		logutil.BgLogger().Error("Failed to get ServerInfo for uptime status", zap.Error(err))
	}

	stats := fmt.Sprintf("Uptime: %d Threads: 0 Questions: 0 Slow queries: 0 Opens: 0 Flush tables: 0 Open tables: 0 Queries per second avg: 0.000",
		uptimeSec)
	pkt := append(cc.alloc.AllocWithLen(4, len(stats)), stats...)
	if err = cc.writePacket(pkt); err != nil {
		return err
	}
	return cc.flush(ctx)
}
// useDB switches the current database of the session to db by executing a
// `USE` statement, and records the name on the connection on success.
//
// It returns the parsed statement node (nil when parsing fails) together with
// any parse or execution error.
func (cc *clientConn) useDB(ctx context.Context, db string) (node ast.StmtNode, err error) {
	// if input is "use `SELECT`", mysql client just send "SELECT"
	// so we add `` around db. The client sends the raw schema name, so any
	// backtick inside it must be doubled to keep it a single quoted identifier.
	quoted := "`" + strings.ReplaceAll(db, "`", "``") + "`"
	stmts, err := cc.ctx.Parse(ctx, "use "+quoted)
	if err != nil {
		return nil, err
	}
	_, err = cc.ctx.ExecuteStmt(ctx, stmts[0])
	if err != nil {
		return stmts[0], err
	}
	cc.dbname = db
	return stmts[0], err
}
// flush sends the buffered packets to the client.
//
// When ctx carries statement execution details, the time spent in this call is
// added to their WriteSQLRespDuration. After flushing, every session warning
// that wraps an *errors.Error is counted per user and peer host.
func (cc *clientConn) flush(ctx context.Context) error {
	var (
		detail *execdetails.StmtExecDetails
		begin  time.Time
	)
	if raw := ctx.Value(execdetails.StmtExecDetailKey); raw != nil {
		//nolint:forcetypeassert
		detail = raw.(*execdetails.StmtExecDetails)
		begin = time.Now()
	}
	defer func() {
		if detail != nil {
			detail.WriteSQLRespDuration += time.Since(begin)
		}
		trace.StartRegion(ctx, "FlushClientConn").End()
		sessCtx := cc.getCtx()
		if sessCtx == nil || sessCtx.WarningCount() == 0 {
			return
		}
		for _, w := range sessCtx.GetWarnings() {
			var coded *errors.Error
			if goerr.As(w.Err, &coded) {
				errno.IncrementWarning(uint16(coded.Code()), cc.user, cc.peerHost)
			}
		}
	}()
	failpoint.Inject("FakeClientConn", func() {
		if cc.pkt == nil {
			failpoint.Return(nil)
		}
	})
	return cc.pkt.flush()
}
// writeOK writes an OK packet carrying the current session status and flushes
// it to the client.
func (cc *clientConn) writeOK(ctx context.Context) error {
	status := cc.ctx.Status()
	return cc.writeOkWith(ctx, mysql.OKHeader, true, status)
}
// writeOkWith writes an OK-style packet that starts with header and carries the
// session's affected rows, last insert ID and, for ClientProtocol41 clients,
// the given status and the warning count. The session's last message, if any,
// is appended as a length-encoded string. The packet is flushed only when
// flush is true.
func (cc *clientConn) writeOkWith(ctx context.Context, header byte, flush bool, status uint16) error {
	info := cc.ctx.LastMessage()
	affected := cc.ctx.AffectedRows()
	insertID := cc.ctx.LastInsertID()
	warnings := cc.ctx.WarningCount()

	hasInfo := len(info) > 0
	infoLen := 0
	if hasInfo {
		infoLen = util2.LengthEncodedIntSize(uint64(len(info))) + len(info)
	}

	pkt := cc.alloc.AllocWithLen(4, 32+infoLen)
	pkt = append(pkt, header)
	pkt = dump.LengthEncodedInt(pkt, affected)
	pkt = dump.LengthEncodedInt(pkt, insertID)
	if cc.capability&mysql.ClientProtocol41 != 0 {
		pkt = dump.Uint16(pkt, status)
		pkt = dump.Uint16(pkt, warnings)
	}
	if hasInfo {
		// although MySQL manual says the info message is string<EOF>(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html),
		// it is actually string<lenenc>
		pkt = dump.LengthEncodedString(pkt, []byte(info))
	}

	if err := cc.writePacket(pkt); err != nil {
		return err
	}
	if !flush {
		return nil
	}
	return cc.flush(ctx)
}
// writeError converts e into a MySQL error packet, writes it and flushes it to
// the client.
//
// The root cause of e is used: a *terror.Error is translated with
// terror.ToSQLError, anything else is reported as ErrUnknown with the cause's
// message. The error code is remembered in cc.lastCode and counted per user
// and peer host. e must be non-nil.
func (cc *clientConn) writeError(ctx context.Context, e error) error {
	var m *mysql.SQLError
	// errors.Cause already yields the innermost error, so one unwrap and one
	// type test are sufficient.
	cause := errors.Cause(e)
	if te, ok := cause.(*terror.Error); ok {
		m = terror.ToSQLError(te)
	} else {
		m = mysql.NewErrf(mysql.ErrUnknown, "%s", nil, cause.Error())
	}

	cc.lastCode = m.Code
	defer errno.IncrementError(m.Code, cc.user, cc.peerHost)
	data := cc.alloc.AllocWithLen(4, 16+len(m.Message))
	data = append(data, mysql.ErrHeader)
	// The error code is encoded as a little-endian uint16.
	data = append(data, byte(m.Code), byte(m.Code>>8))
	if cc.capability&mysql.ClientProtocol41 > 0 {
		// Protocol 4.1 clients additionally get the '#' marker and SQL state.
		data = append(data, '#')
		data = append(data, m.State...)
	}

	data = append(data, m.Message...)

	err := cc.writePacket(data)
	if err != nil {
		return err
	}
	return cc.flush(ctx)
}
// writeEOF writes an EOF packet, or an OK packet with the EOF header when the
// client has set ClientDeprecateEOF.
// The stream is not flushed here because more packets may follow.
// serverStatus is the status flag word placed in the packet; callers are
// responsible for passing a correct value.
func (cc *clientConn) writeEOF(ctx context.Context, serverStatus uint16) error {
	if cc.capability&mysql.ClientDeprecateEOF != 0 {
		return cc.writeOkWith(ctx, mysql.EOFHeader, false, serverStatus)
	}

	pkt := append(cc.alloc.AllocWithLen(4, 9), mysql.EOFHeader)
	if cc.capability&mysql.ClientProtocol41 != 0 {
		// Warning count first, then the status flags.
		pkt = dump.Uint16(pkt, cc.ctx.WarningCount())
		pkt = dump.Uint16(pkt, serverStatus)
	}
	return cc.writePacket(pkt)
}
// writeReq sends a local-infile request packet for filePath to the client and
// flushes it.
func (cc *clientConn) writeReq(ctx context.Context, filePath string) error {
	pkt := cc.alloc.AllocWithLen(4, 5+len(filePath))
	pkt = append(append(pkt, mysql.LocalInFileHeader), filePath...)
	if err := cc.writePacket(pkt); err != nil {
		return err
	}
	return cc.flush(ctx)
}
// handleLoadData does the additional work after processing the 'load data' query.
// It sends client a file path, then reads the file content from client, inserts data into database.
//
// It returns errNotAllowedCommand when the client did not set the
// ClientLocalFiles capability, an error when loadDataWorker is nil, and
// otherwise the error (if any) from sending the request or loading the data.
// If loading fails before the client has finished sending, the remaining
// packets are read and discarded; exeerrors.ErrQueryInterrupted is returned if
// the session is killed during that.
func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *executor.LoadDataWorker) error {
	// If the server handles the load data request, the client has to set the ClientLocalFiles capability.
	if cc.capability&mysql.ClientLocalFiles == 0 {
		return errNotAllowedCommand
	}
	if loadDataWorker == nil {
		return errors.New("load data info is empty")
	}
	infile := loadDataWorker.GetInfilePath()
	err := cc.writeReq(ctx, infile)
	if err != nil {
		return err
	}

	var (
		// use Pipe to convert cc.readPacket to io.Reader
		r, w = io.Pipe()
		// drained is set once the empty terminating packet has been read. It is
		// written by the reader goroutine and read only after wg.Wait().
		drained bool
		wg      sync.WaitGroup
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		//nolint: errcheck
		defer w.Close()

		var (
			data []byte
			err2 error
		)
		for {
			// Fetch the next packet only once the previous one is fully written.
			if len(data) == 0 {
				data, err2 = cc.readPacket()
				if err2 != nil {
					// Surface the read error to the consumer of the pipe.
					w.CloseWithError(err2)
					return
				}
				// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
				if len(data) == 0 {
					drained = true
					return
				}
			}

			n, err3 := w.Write(data)
			if err3 != nil {
				logutil.Logger(ctx).Error("write data meet error", zap.Error(err3))
				return
			}
			data = data[n:]
		}
	}()

	ctx = kv.WithInternalSourceType(ctx, kv.InternalLoadData)
	err = loadDataWorker.LoadLocal(ctx, r)
	// Closing the read side makes a pending w.Write in the goroutine fail, so
	// wg.Wait() below cannot block on an unread pipe.
	_ = r.Close()
	wg.Wait()

	if err != nil {
		if !drained {
			logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection")
		}
		// drain the data from client conn until empty packet received, otherwise the connection will be reset
		// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
		for !drained {
			// check kill flag again, let the draining loop could quit if empty packet could not be received
			if atomic.CompareAndSwapUint32(&loadDataWorker.UserSctx.GetSessionVars().Killed, 1, 0) {
				logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be reset")
				return exeerrors.ErrQueryInterrupted
			}
			curData, err1 := cc.readPacket()
			if err1 != nil {
				logutil.Logger(ctx).Error("drain reading left data encounter errors", zap.Error(err1))
				break
			}
			if len(curData) == 0 {
				drained = true
				logutil.Logger(ctx).Info("draining finished for error", zap.Error(err))
				break
			}
		}
	}
	return err
}
// getDataFromPath asks the client for the file at path and returns the
// concatenated contents it sends back. Reading stops at the first empty
// packet. An io.EOF from reading a packet is not treated as a failure; any
// other read error is returned.
func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) {
	if err := cc.writeReq(ctx, path); err != nil {
		return nil, err
	}
	var contents []byte
	for {
		chunk, err := cc.readPacket()
		if err != nil && terror.ErrorNotEqual(err, io.EOF) {
			return nil, err
		}
		if len(chunk) == 0 {
			return contents, nil
		}
		contents = append(contents, chunk...)
	}
}
// handleLoadStats does the additional work after processing the 'load stats' query.
// It requests the file named by loadStatsInfo.Path from the client and passes
// the received content to loadStatsInfo.Update. An empty file is a no-op.
// The client must have set the ClientLocalFiles capability.
func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *executor.LoadStatsInfo) error {
	switch {
	case cc.capability&mysql.ClientLocalFiles == 0:
		return errNotAllowedCommand
	case loadStatsInfo == nil:
		return errors.New("load stats: info is empty")
	}
	content, err := cc.getDataFromPath(ctx, loadStatsInfo.Path)
	switch {
	case err != nil:
		return err
	case len(content) == 0:
		return nil
	}
	return loadStatsInfo.Update(content)
}
// handleIndexAdvise requests the file named by indexAdviseInfo.Path from the
// client and runs index advice on its content. An empty file is an error.
// The client must have set the ClientLocalFiles capability.
func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *executor.IndexAdviseInfo) error {
	switch {
	case cc.capability&mysql.ClientLocalFiles == 0:
		return errNotAllowedCommand
	case indexAdviseInfo == nil:
		return errors.New("Index Advise: info is empty")
	}
	content, err := cc.getDataFromPath(ctx, indexAdviseInfo.Path)
	switch {
	case err != nil:
		return err
	case len(content) == 0:
		return errors.New("Index Advise: infile is empty")
	}
	adviceErr := indexAdviseInfo.GetIndexAdvice(ctx, content)
	if adviceErr != nil {
		return adviceErr
	}
	// TODO: Write the rss []ResultSet. It will be done in another PR.
	return nil
}
// handlePlanReplayerLoad requests the file named by planReplayerLoadInfo.Path
// from the client and passes the received content to
// planReplayerLoadInfo.Update. An empty file is a no-op.
// The client must have set the ClientLocalFiles capability.
func (cc *clientConn) handlePlanReplayerLoad(ctx context.Context, planReplayerLoadInfo *executor.PlanReplayerLoadInfo) error {
	switch {
	case cc.capability&mysql.ClientLocalFiles == 0:
		return errNotAllowedCommand
	case planReplayerLoadInfo == nil:
		return errors.New("plan replayer load: info is empty")
	}
	content, err := cc.getDataFromPath(ctx, planReplayerLoadInfo.Path)
	switch {
	case err != nil:
		return err
	case len(content) == 0:
		return nil
	}
	return planReplayerLoadInfo.Update(content)
}
// handlePlanReplayerDump requests the file named by e.Path from the client and
// passes the received content to e.DumpSQLsFromFile. An empty file is a no-op;
// a failure to fetch the file is logged and returned.
// The client must have set the ClientLocalFiles capability.
func (cc *clientConn) handlePlanReplayerDump(ctx context.Context, e *executor.PlanReplayerDumpInfo) error {
	switch {
	case cc.capability&mysql.ClientLocalFiles == 0:
		return errNotAllowedCommand
	case e == nil:
		return errors.New("plan replayer dump: executor is empty")
	}
	content, err := cc.getDataFromPath(ctx, e.Path)
	switch {
	case err != nil:
		logutil.BgLogger().Error(err.Error())
		return err
	case len(content) == 0:
		return nil
	}
	return e.DumpSQLsFromFile(ctx, content)
}
// audit reports a general event of the given type to every audit plugin that
// has an OnGeneralEvent callback. The callback receives the session variables,
// the name of the session's current command, and a context carrying the
// session's StartTime. An error from iterating the plugins is only logged.
func (cc *clientConn) audit(eventType plugin.GeneralEvent) {
	notify := func(p *plugin.Plugin) error {
		manifest := plugin.DeclareAuditManifest(p.Manifest)
		if manifest.OnGeneralEvent == nil {
			return nil
		}
		cmdName := mysql.Command2Str[byte(atomic.LoadUint32(&cc.ctx.GetSessionVars().CommandValue))]
		auditCtx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, cc.ctx.GetSessionVars().StartTime)
		manifest.OnGeneralEvent(auditCtx, cc.ctx.GetSessionVars(), eventType, cmdName)
		return nil
	}
	if err := plugin.ForeachPlugin(plugin.Audit, notify); err != nil {
		terror.Log(err)
	}
}
// handleQuery executes the sql query string and writes result set or result ok to the client.
// As the execution time of this function represents the performance of TiDB, we do time log and metrics here.
// Some special queries like `load data` that does not return result, which is handled in handleFileTransInConn.
//
// The text may contain several statements; they are executed in order and
// execution stops at the first statement whose error cannot be recovered by a
// retry. An empty statement list is answered with an OK packet.
func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
	defer trace.StartRegion(ctx, "handleQuery").End()
	sessVars := cc.ctx.GetSessionVars()
	sc := sessVars.StmtCtx
	// Remember the warnings present before parsing so that the ones added by
	// the parser can be separated out below.
	prevWarns := sc.GetWarnings()
	var stmts []ast.StmtNode
	cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc)
	if stmts, err = cc.ctx.Parse(ctx, sql); err != nil {
		cc.onExtensionSQLParseFailed(sql, err)
		return err
	}

	if len(stmts) == 0 {
		return cc.writeOK(ctx)
	}

	warns := sc.GetWarnings()
	parserWarns := warns[len(prevWarns):]

	var pointPlans []plannercore.Plan
	cc.ctx.GetSessionVars().InMultiStmts = false
	if len(stmts) > 1 {
		// The client gets to choose if it allows multi-statements, and
		// probably defaults OFF. This helps prevent against SQL injection attacks
		// by early terminating the first statement, and then running an entirely
		// new statement.

		capabilities := cc.ctx.GetSessionVars().ClientCapability
		if capabilities&mysql.ClientMultiStatements < 1 {
			// The client does not have multi-statement enabled. We now need to determine
			// how to handle an unsafe situation based on the multiStmt sysvar.
			switch cc.ctx.GetSessionVars().MultiStatementMode {
			case variable.OffInt:
				err = errMultiStatementDisabled
				return err
			case variable.OnInt:
				// multi statement is fully permitted, do nothing
			default:
				// Any other mode: run the statements but attach a warning.
				warn := stmtctx.SQLWarn{Level: stmtctx.WarnLevelWarning, Err: errMultiStatementDisabled}
				parserWarns = append(parserWarns, warn)
			}
		}
		cc.ctx.GetSessionVars().InMultiStmts = true

		// Only pre-build point plans for multi-statement query
		pointPlans, err = cc.prefetchPointPlanKeys(ctx, stmts)
		if err != nil {
			for _, stmt := range stmts {
				cc.onExtensionStmtEnd(stmt, false, err)
			}
			return err
		}
		metrics.NumOfMultiQueryHistogram.Observe(float64(len(stmts)))
	}
	if len(pointPlans) > 0 {
		defer cc.ctx.ClearValue(plannercore.PointPlanKey)
	}
	var retryable bool
	var lastStmt ast.StmtNode
	var expiredStmtTaskID uint64
	for i, stmt := range stmts {
		// Report the end of the previous statement before starting the next;
		// reaching this point means the previous one finished without error.
		if lastStmt != nil {
			cc.onExtensionStmtEnd(lastStmt, true, nil)
		}
		lastStmt = stmt

		// expiredTaskID is the task ID of the previous statement. When executing a stmt,
		// the StmtCtx will be reinit and the TaskID will change. We can compare the StmtCtx.TaskID
		// with the previous one to determine whether StmtCtx has been inited for the current stmt.
		expiredStmtTaskID = sessVars.StmtCtx.TaskID

		if len(pointPlans) > 0 {
			// Save the point plan in Session, so we don't need to build the point plan again.
			cc.ctx.SetValue(plannercore.PointPlanKey, plannercore.PointPlanVal{Plan: pointPlans[i]})
		}
		retryable, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1)
		if err != nil {
			// Ask the transaction manager what to do next for this error.
			action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err)
			if txnErr != nil {
				err = txnErr
				break
			}

			if retryable && action == sessiontxn.StmtActionRetryReady {
				// Re-run the statement once with the Retrying flag set.
				cc.ctx.GetSessionVars().RetryInfo.Retrying = true
				_, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1)
				cc.ctx.GetSessionVars().RetryInfo.Retrying = false
				if err != nil {
					break
				}
				continue
			}
			// From here on only a retryable TiFlash timeout is handled.
			if !retryable || !errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) {
				break
			}
			_, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash]
			if !allowTiFlashFallback {
				break
			}
			// When the TiFlash server seems down, we append a warning to remind the user to check the status of the TiFlash
			// server and fallback to TiKV.
			warns := append(parserWarns, stmtctx.SQLWarn{Level: stmtctx.WarnLevelError, Err: err})
			// TiFlash is removed from the read engines only for this retry and
			// restored right after it.
			delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash)
			_, err = cc.handleStmt(ctx, stmt, warns, i == len(stmts)-1)
			cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{}
			if err != nil {
				break
			}
		}
	}

	if lastStmt != nil {
		// The second argument tells whether StmtCtx was re-initialized for the
		// last statement (its TaskID differs from the one recorded before it ran).
		cc.onExtensionStmtEnd(lastStmt, sessVars.StmtCtx.TaskID != expiredStmtTaskID, err)
	}

	return err
}
// prefetchPointPlanKeys extracts the point keys in multi-statement query,
// use BatchGet to get the keys, so the values will be cached in the snapshot cache, save RPC call cost.
// For pessimistic transaction, the keys will be batch locked.
//
// It returns one plan slot per statement (nil where no fast plan was built).
// A nil slice with a nil error means prefetching was skipped: no valid
// transaction, an unsupported pessimistic configuration, a `USE` statement in
// the batch, or a preprocess failure. A failure to lock the keys is logged and
// otherwise ignored.
func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.StmtNode) ([]plannercore.Plan, error) {
	txn, err := cc.ctx.Txn(false)
	if err != nil {
		return nil, err
	}
	if !txn.Valid() {
		// Only prefetch in-transaction query for simplicity.
		// Later we can support out-transaction multi-statement query.
		return nil, nil
	}
	vars := cc.ctx.GetSessionVars()
	if vars.TxnCtx.IsPessimistic {
		if vars.IsIsolation(ast.ReadCommitted) {
			// TODO: to support READ-COMMITTED, we need to avoid getting new TS for each statement in the query.
			return nil, nil
		}
		if vars.TxnCtx.GetForUpdateTS() != vars.TxnCtx.StartTS {
			// Do not handle the case that ForUpdateTS is changed for simplicity.
			return nil, nil
		}
	}
	pointPlans := make([]plannercore.Plan, len(stmts))
	var idxKeys []kv.Key //nolint: prealloc
	var rowKeys []kv.Key //nolint: prealloc

	// handlePlan collects the index keys / row keys touched by a point-get or
	// batch-point-get plan. resetStmtCtxFn is invoked before index keys are
	// encoded.
	handlePlan := func(p plannercore.PhysicalPlan, resetStmtCtxFn func()) error {
		var tableID int64
		switch v := p.(type) {
		case *plannercore.PointGetPlan:
			if v.PartitionInfo != nil {
				tableID = v.PartitionInfo.ID
			} else {
				tableID = v.TblInfo.ID
			}
			if v.IndexInfo != nil {
				resetStmtCtxFn()
				idxKey, err1 := executor.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, v.IndexValues, tableID)
				if err1 != nil {
					return err1
				}
				idxKeys = append(idxKeys, idxKey)
			} else {
				rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tableID, v.Handle))
			}
		case *plannercore.BatchPointGetPlan:
			if v.PartitionInfos != nil && len(v.PartitionIDs) == 0 {
				// skip when PartitionIDs is not initialized.
				return nil
			}
			// getPhysID returns the physical table ID for the i-th key: the
			// partition ID for partitioned tables, the table ID otherwise.
			getPhysID := func(i int) int64 {
				if v.PartitionInfos == nil {
					return v.TblInfo.ID
				}
				return v.PartitionIDs[i]
			}
			if v.IndexInfo != nil {
				resetStmtCtxFn()
				for i, idxVals := range v.IndexValues {
					idxKey, err1 := executor.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, idxVals, getPhysID(i))
					if err1 != nil {
						return err1
					}
					idxKeys = append(idxKeys, idxKey)
				}
			} else {
				for i, handle := range v.Handles {
					rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(getPhysID(i), handle))
				}
			}
		}
		return nil
	}

	sc := vars.StmtCtx
	for i, stmt := range stmts {
		if _, ok := stmt.(*ast.UseStmt); ok {
			// If there is a "use db" statement, we shouldn't cache even if it's possible.
			// Consider the scenario where there are statements that could execute on multiple
			// schemas, but the schema is actually different.
			return nil, nil
		}
		// TODO: the preprocess is run twice, we should find some way to avoid do it again.
		if err = plannercore.Preprocess(ctx, cc.getCtx(), stmt); err != nil {
			// error might happen, see https://github.com/pingcap/tidb/issues/39664
			return nil, nil
		}
		p := plannercore.TryFastPlan(cc.ctx.Session, stmt)
		pointPlans[i] = p
		if p == nil {
			continue
		}
		// Only support Update and Delete for now.
		// TODO: support other point plans.
		switch x := p.(type) {
		case *plannercore.Update:
			updateStmt, ok := stmt.(*ast.UpdateStmt)
			if !ok {
				logutil.BgLogger().Warn("unexpected statement type for Update plan",
					zap.String("type", fmt.Sprintf("%T", stmt)))
				continue
			}
			err = handlePlan(x.SelectPlan, func() {
				executor.ResetUpdateStmtCtx(sc, updateStmt, vars)
			})
			if err != nil {
				return nil, err
			}
		case *plannercore.Delete:
			deleteStmt, ok := stmt.(*ast.DeleteStmt)
			if !ok {
				logutil.BgLogger().Warn("unexpected statement type for Delete plan",
					zap.String("type", fmt.Sprintf("%T", stmt)))
				continue
			}
			err = handlePlan(x.SelectPlan, func() {
				executor.ResetDeleteStmtCtx(sc, deleteStmt, vars)
			})
			if err != nil {
				return nil, err
			}
		}
	}
	if len(idxKeys) == 0 && len(rowKeys) == 0 {
		return pointPlans, nil
	}
	snapshot := txn.GetSnapshot()
	// Resolve unique index keys to handles so the matching row keys can be
	// prefetched (or locked) as well.
	idxVals, err1 := snapshot.BatchGet(ctx, idxKeys)
	if err1 != nil {
		return nil, err1
	}
	for idxKey, idxVal := range idxVals {
		h, err2 := tablecodec.DecodeHandleInUniqueIndexValue(idxVal, false)
		if err2 != nil {
			return nil, err2
		}
		tblID := tablecodec.DecodeTableID(hack.Slice(idxKey))
		rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tblID, h))
	}
	if vars.TxnCtx.IsPessimistic {
		allKeys := append(rowKeys, idxKeys...)
		if lockErr := executor.LockKeys(ctx, cc.getCtx(), vars.LockWaitTimeout, allKeys...); lockErr != nil {
			// suppress the lock error, we are not going to handle it here for simplicity.
			// Log it before dropping it so the warning carries the real cause.
			logutil.BgLogger().Warn("lock keys error on prefetch", zap.Error(lockErr))
		}
	} else {
		_, err = snapshot.BatchGet(ctx, rowKeys)
		if err != nil {
			return nil, err
		}
	}
	return pointPlans, nil
}
// handleStmt executes one statement and writes its outcome to the client.
// warns are appended to the statement context only when lastStmt is true; otherwise
// the ServerMoreResultsExists flag is added to the status sent to the client.
// The first return value indicates whether the call had no side effect and can be retried.
// Currently it is used to fall back to TiKV when TiFlash is down.
func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns []stmtctx.SQLWarn, lastStmt bool) (bool, error) {
	ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
	ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})

	region := trace.StartRegion(ctx, "ExecuteStmt")
	cc.audit(plugin.Starting)
	resultSet, execErr := cc.ctx.ExecuteStmt(ctx, stmt)
	region.End()

	// When a result set is returned, detaching the statement tracker from the
	// session tracker is done by its Close in most cases. When there is no result
	// set but an error, the detachment is done in `handleNoDelay`.
	hasResultSet := resultSet != nil
	if hasResultSet {
		defer terror.Call(resultSet.Close)
	}
	if execErr != nil {
		// An error raised during the planner phase or the executor.Open phase
		// leaves the result set nil, so StmtCtx.MemTracker and
		// StmtCtx.DiskTracker are still attached; detach them manually.
		sessVars := cc.ctx.GetSessionVars()
		if sessVars != nil && sessVars.StmtCtx != nil {
			sessVars.StmtCtx.DetachMemDiskTracker()
		}
		return true, execErr
	}

	serverStatus := cc.ctx.Status()
	if !lastStmt {
		serverStatus |= mysql.ServerMoreResultsExists
	} else {
		cc.ctx.GetSessionVars().StmtCtx.AppendWarnings(warns)
	}

	if !hasResultSet {
		// No rows to send: the statement may have requested a follow-up exchange
		// on this connection (load data, load stats, ...).
		handled, transErr := cc.handleFileTransInConn(ctx, serverStatus)
		if handled {
			if v := cc.ctx.Value(session.ExecStmtVarKey); v != nil {
				//nolint:forcetypeassert
				v.(*executor.ExecStmt).FinishExecuteStmt(0, transErr, false)
			}
		}
		return false, transErr
	}

	if cc.getStatus() == connStatusShutdown {
		return false, exeerrors.ErrQueryInterrupted
	}
	canRetry, writeErr := cc.writeResultSet(ctx, resultSet, false, serverStatus, 0)
	if writeErr != nil {
		return canRetry, writeErr
	}
	return false, nil
}
// handleFileTransInConn runs the follow-up handlers that a statement may have
// registered on the session context (load data, load stats, index advise, plan
// replayer load, plan replayer dump), in that order, and then writes an OK packet
// carrying status. The first return value reports whether any handler was found.
// Each registered value is reset to nil when this function returns.
func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) {
	var handled bool

	if v := cc.ctx.Value(executor.LoadDataVarKey); v != nil {
		handled = true
		defer cc.ctx.SetValue(executor.LoadDataVarKey, nil)
		//nolint:forcetypeassert
		worker := v.(*executor.LoadDataWorker)
		if err := cc.handleLoadData(ctx, worker); err != nil {
			return handled, err
		}
	}

	if v := cc.ctx.Value(executor.LoadStatsVarKey); v != nil {
		handled = true
		defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil)
		//nolint:forcetypeassert
		info := v.(*executor.LoadStatsInfo)
		if err := cc.handleLoadStats(ctx, info); err != nil {
			return handled, err
		}
	}

	if v := cc.ctx.Value(executor.IndexAdviseVarKey); v != nil {
		handled = true
		defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil)
		//nolint:forcetypeassert
		info := v.(*executor.IndexAdviseInfo)
		if err := cc.handleIndexAdvise(ctx, info); err != nil {
			return handled, err
		}
	}

	if v := cc.ctx.Value(executor.PlanReplayerLoadVarKey); v != nil {
		handled = true
		defer cc.ctx.SetValue(executor.PlanReplayerLoadVarKey, nil)
		//nolint:forcetypeassert
		info := v.(*executor.PlanReplayerLoadInfo)
		if err := cc.handlePlanReplayerLoad(ctx, info); err != nil {
			return handled, err
		}
	}

	if v := cc.ctx.Value(executor.PlanReplayerDumpVarKey); v != nil {
		handled = true
		defer cc.ctx.SetValue(executor.PlanReplayerDumpVarKey, nil)
		//nolint:forcetypeassert
		info := v.(*executor.PlanReplayerDumpInfo)
		if err := cc.handlePlanReplayerDump(ctx, info); err != nil {
			return handled, err
		}
	}

	return handled, cc.writeOkWith(ctx, mysql.OKHeader, true, status)
}
// handleFieldList writes the field list of a table to the client.
// sql holds the table name followed by a terminating \x00 character; only the
// part before the first \x00 is used as the table name.
func (cc *clientConn) handleFieldList(ctx context.Context, sql string) (err error) {
	tableName, _, _ := strings.Cut(sql, "\x00")
	columns, err := cc.ctx.FieldList(tableName)
	if err != nil {
		return err
	}

	buf := cc.alloc.AllocWithLen(4, 1024)
	cc.initResultEncoder(ctx)
	defer cc.rsEncoder.Clean()

	for _, col := range columns {
		// Reuse the buffer, keeping only the 4 leading bytes before the payload.
		buf = col.DumpWithDefault(buf[:4], cc.rsEncoder)
		if err = cc.writePacket(buf); err != nil {
			return err
		}
	}
	if err = cc.writeEOF(ctx, cc.ctx.Status()); err != nil {
		return err
	}
	return cc.flush(ctx)
}
// writeResultSet sends the rows of rs to the client and flushes the connection.
// If binary is true, rows are encoded in BINARY format.
// serverStatus is the status flag word reported to the client; when it carries the
// cursor-exists flag the rows are written through writeChunksWithFetchSize instead
// of writeChunks.
// fetchSize is the desired number of rows to be sent each time when the client uses a cursor.
// retryable indicates whether the call had no side effect and can be retried to
// correct the error. The call has side effect in cursor mode or once data has been
// sent to the client. Currently retryable is used to fall back to TiKV when TiFlash is down.
func (cc *clientConn) writeResultSet(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (retryable bool, runErr error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		// Only string panics that start with the memory-exceed warning message
		// are turned into an error; anything else is re-raised.
		msg, isString := recovered.(string)
		if !(isString && strings.HasPrefix(msg, memory.PanicMemoryExceedWarnMsg)) {
			panic(recovered)
		}
		// TODO(jianzhang.zj): add metrics here
		runErr = errors.Errorf("%v", recovered)
		logutil.Logger(ctx).Error("write query result panic", zap.Stringer("lastSQL", getLastStmtInConn{cc}), zap.Stack("stack"), zap.Any("recover", recovered))
	}()

	cc.initResultEncoder(ctx)
	defer cc.rsEncoder.Clean()

	if !mysql.HasCursorExistsFlag(serverStatus) {
		if canRetry, err := cc.writeChunks(ctx, rs, binary, serverStatus); err != nil {
			return canRetry, err
		}
		return false, cc.flush(ctx)
	}

	if err := cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize); err != nil {
		return false, err
	}
	return false, cc.flush(ctx)
}
// writeColumnInfo writes a packet holding the length-encoded number of columns,
// followed by one packet per column definition. It stops at the first write error.
func (cc *clientConn) writeColumnInfo(columns []*column.Info) error {
	buf := cc.alloc.AllocWithLen(4, 1024)
	buf = dump.LengthEncodedInt(buf, uint64(len(columns)))
	if err := cc.writePacket(buf); err != nil {
		return err
	}

	for _, col := range columns {
		// Reuse the buffer, keeping only the 4 leading bytes before the payload.
		buf = col.Dump(buf[:4], cc.rsEncoder)
		if err := cc.writePacket(buf); err != nil {
			return err
		}
	}
	return nil
}
// writeChunks streams every row produced by rs into the connection.
// It calls rs.Next repeatedly, each call filling one chunk, until a call yields an
// empty chunk, then writes a final EOF packet.
// binary selects dumpBinaryRow instead of dumpTextRow to encode each row.
// serverStatus is the status flag word passed to the EOF packets.
// The first return value is true only when the error comes from the first call of
// ResultSet.Next (or from the failpoint that simulates it), i.e. before any column
// info or row has been written; every other return path reports false.
func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) (bool, error) {
data := cc.alloc.AllocWithLen(4, 1024)
req := rs.NewChunk(cc.chunkAlloc)
// gotColumnInfo records whether the column definitions were already sent.
gotColumnInfo := false
// firstNext stays true until a non-empty chunk has been processed.
firstNext := true
// validNextCount counts the calls of rs.Next that returned at least one row.
validNextCount := 0
var start time.Time
// stmtDetail, when present in ctx, accumulates the time spent writing the response.
var stmtDetail *execdetails.StmtExecDetails
stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey)
if stmtDetailRaw != nil {
//nolint:forcetypeassert
stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails)
}
for {
// Test-only injection of errors at chosen iterations of the fetch loop.
failpoint.Inject("fetchNextErr", func(value failpoint.Value) {
//nolint:forcetypeassert
switch value.(string) {
case "firstNext":
failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout)
case "secondNext":
if !firstNext {
failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout)
}
case "secondNextAndRetConflict":
if !firstNext && validNextCount > 1 {
failpoint.Return(firstNext, kv.ErrWriteConflict)
}
}
})
// Here server.tidbResultSet implements Next method.
err := rs.Next(ctx, req)
if err != nil {
return firstNext, err
}
if !gotColumnInfo {
// We need to call Next before we get columns.
// Otherwise, we will get incorrect columns info.
columns := rs.Columns()
if stmtDetail != nil {
start = time.Now()
}
if err = cc.writeColumnInfo(columns); err != nil {
return false, err
}
if cc.capability&mysql.ClientDeprecateEOF == 0 {
// metadata only needs EOF marker for old clients without ClientDeprecateEOF
if err = cc.writeEOF(ctx, serverStatus); err != nil {
return false, err
}
}
if stmtDetail != nil {
stmtDetail.WriteSQLRespDuration += time.Since(start)
}
gotColumnInfo = true
}
rowCount := req.NumRows()
// An empty chunk ends the loop.
if rowCount == 0 {
break
}
validNextCount++
firstNext = false
reg := trace.StartRegion(ctx, "WriteClientConn")
if stmtDetail != nil {
start = time.Now()
}
for i := 0; i < rowCount; i++ {
// Reuse the buffer, keeping only the 4 leading bytes before the row payload.
data = data[0:4]
if binary {
data, err = dumpBinaryRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder)
} else {
data, err = dumpTextRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder)
}
if err != nil {
reg.End()
return false, err
}
if err = cc.writePacket(data); err != nil {
reg.End()
return false, err
}
}
reg.End()
if stmtDetail != nil {
stmtDetail.WriteSQLRespDuration += time.Since(start)
}
}
// Terminate the result set with an EOF packet and account for its write time.
if stmtDetail != nil {
start = time.Now()
}
err := cc.writeEOF(ctx, serverStatus)
if stmtDetail != nil {
stmtDetail.WriteSQLRespDuration += time.Since(start)
}
return false, err
}
// writeChunksWithFetchSize sends up to fetchSize of the rows already fetched by rs
// to the client, always encoded with dumpBinaryRow, and stores the remaining rows
// back into rs.
// serverStatus is the status flag word reported to the client.
// fetchSize is the desired number of rows to be sent each time when the client uses a cursor.
// When no fetched row is left, only an EOF packet is written, with the cursor-exists
// flag cleared and the last-row-sent flag set.
func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error {
	remaining := rs.GetFetchedRows()

	// Tell the client COM_STMT_FETCH has finished by setting proper serverStatus.
	if len(remaining) == 0 {
		serverStatus &^= mysql.ServerStatusCursorExists
		serverStatus |= mysql.ServerStatusLastRowSend
		return cc.writeEOF(ctx, serverStatus)
	}

	// Split off the batch sent now from the rows kept for later fetches.
	batch := remaining
	if fetchSize < len(remaining) {
		batch, remaining = remaining[:fetchSize], remaining[fetchSize:]
	} else {
		remaining = remaining[:0]
	}
	rs.StoreFetchedRows(remaining)

	buf := cc.alloc.AllocWithLen(4, 1024)
	var stmtDetail *execdetails.StmtExecDetails
	if raw := ctx.Value(execdetails.StmtExecDetailKey); raw != nil {
		//nolint:forcetypeassert
		stmtDetail = raw.(*execdetails.StmtExecDetails)
	}
	// timed tells whether write durations must be accumulated into stmtDetail.
	timed := stmtDetail != nil

	var (
		err   error
		begin time.Time
	)
	if timed {
		begin = time.Now()
	}
	for i := range batch {
		// Reuse the buffer, keeping only the 4 leading bytes before the row payload.
		buf, err = dumpBinaryRow(buf[:4], rs.Columns(), batch[i], cc.rsEncoder)
		if err != nil {
			return err
		}
		if err = cc.writePacket(buf); err != nil {
			return err
		}
	}
	if timed {
		stmtDetail.WriteSQLRespDuration += time.Since(begin)
	}

	if notifier, ok := rs.(fetchNotifier); ok {
		notifier.OnFetchReturned()
	}

	if timed {
		begin = time.Now()
	}
	err = cc.writeEOF(ctx, serverStatus)
	if timed {
		stmtDetail.WriteSQLRespDuration += time.Since(begin)
	}
	return err
}
// setConn wraps conn in a buffered read connection and makes the packet IO use it.
// An existing packet IO is kept and only re-pointed at the new connection so that
// its current sequence number is preserved; otherwise a new packet IO is created.
func (cc *clientConn) setConn(conn net.Conn) {
	cc.bufReadConn = newBufferedReadConn(conn)
	if cc.pkt != nil {
		cc.pkt.setBufferedReadConn(cc.bufReadConn)
		return
	}
	cc.pkt = newPacketIO(cc.bufReadConn)
}
// upgradeToTLS performs the server side TLS handshake on the current connection
// and, on success, switches the connection to the TLS one. On handshake failure
// the error is returned and the connection is left untouched.
func (cc *clientConn) upgradeToTLS(tlsConfig *tls.Config) error {
	// Important: the handshake must read from the buffered reader rather than the
	// original net.Conn, because the buffer may already hold data we need.
	secured := tls.Server(cc.bufReadConn, tlsConfig)
	err := secured.Handshake()
	if err != nil {
		return err
	}
	cc.setConn(secured)
	cc.tlsConn = secured
	return nil
}
// handleChangeUser processes a change-user request.
// data is parsed as: null-terminated user name, one length byte followed by that
// many bytes of auth data, null-terminated database name, then optionally two
// character-set bytes (skipped when ClientProtocol41 is set) and a null-terminated
// auth plugin name (read when ClientPluginAuth is set).
// The old session is closed, a new one is opened and authenticated, and the common
// connection-reset steps are run. mysql.ErrMalformPacket is returned when the
// length byte or the auth data is missing.
func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
	userName, rest := util2.ParseNullTermString(data)
	cc.user = string(hack.String(userName))
	if len(rest) == 0 {
		return mysql.ErrMalformPacket
	}

	authLen := int(rest[0])
	rest = rest[1:]
	if len(rest) < authLen {
		return mysql.ErrMalformPacket
	}
	authData := rest[:authLen]
	rest = rest[authLen:]

	dbNameB, rest := util2.ParseNullTermString(rest)
	cc.dbname = string(hack.String(dbNameB))

	var pluginName string
	if len(rest) > 0 {
		// skip character set
		if cc.capability&mysql.ClientProtocol41 != 0 && len(rest) >= 2 {
			rest = rest[2:]
		}
		if cc.capability&mysql.ClientPluginAuth != 0 && len(rest) > 0 {
			pluginNameB, _ := util2.ParseNullTermString(rest)
			pluginName = string(hack.String(pluginNameB))
		}
	}

	if closeErr := cc.ctx.Close(); closeErr != nil {
		logutil.Logger(ctx).Debug("close old context failed", zap.Error(closeErr))
	}
	// The session was closed by `ctx.Close`, so `openSession` must be called
	// explicitly to renew it: `openSessionAndDoAuth` will not open one again
	// because ctx is not nil.
	if openErr := cc.openSession(); openErr != nil {
		return openErr
	}

	fakeResp := &handshakeResponse41{
		Auth:       authData,
		AuthPlugin: pluginName,
		Capability: cc.capability,
	}
	if fakeResp.AuthPlugin != "" {
		failpoint.Inject("ChangeUserAuthSwitch", func(val failpoint.Value) {
			failpoint.Return(errors.Errorf("%v", val))
		})
		switchedAuth, err := cc.checkAuthPlugin(ctx, fakeResp)
		if err != nil {
			return err
		}
		// A non-empty result replaces the auth data taken from the packet.
		if len(switchedAuth) > 0 {
			fakeResp.Auth = switchedAuth
		}
	}

	if authErr := cc.openSessionAndDoAuth(fakeResp.Auth, fakeResp.AuthPlugin); authErr != nil {
		return authErr
	}
	return cc.handleCommonConnectionReset(ctx)
}
// handleResetConnection replaces the session context of the connection with a
// freshly opened one, re-authenticates the previous user without verification,
// restores the current database when one was selected, and then runs the common
// connection-reset steps. A failure to close the old context is only logged.
func (cc *clientConn) handleResetConnection(ctx context.Context) error {
	// Remember the user before the old context is closed.
	currentUser := cc.ctx.GetSessionVars().User
	if closeErr := cc.ctx.Close(); closeErr != nil {
		logutil.Logger(ctx).Debug("close old context failed", zap.Error(closeErr))
	}

	var tlsStatePtr *tls.ConnectionState
	if cc.tlsConn != nil {
		state := cc.tlsConn.ConnectionState()
		tlsStatePtr = &state
	}

	newCtx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions)
	if err != nil {
		return err
	}
	cc.setCtx(newCtx)

	if authenticated := cc.ctx.AuthWithoutVerification(currentUser); !authenticated {
		return errors.New("Could not reset connection")
	}

	// Restore the current DB
	if cc.dbname != "" {
		if _, err = cc.useDB(context.Background(), cc.dbname); err != nil {
			return err
		}
	}

	cc.ctx.SetSessionManager(cc.server)
	return cc.handleCommonConnectionReset(ctx)
}
// handleCommonConnectionReset finishes a connection reset: it refreshes the
// connection info stored in the session variables, raises the ConnReset extension
// event, reports a ChangeUser connection event to every audit plugin that has an
// OnConnectionEvent callback, and finally writes an OK packet. The first error
// returned by a plugin aborts the sequence and is returned.
func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
	info := cc.connectInfo()
	cc.ctx.GetSessionVars().ConnectionInfo = info
	cc.onExtensionConnEvent(extension.ConnReset, nil)

	notifyAudit := func(p *plugin.Plugin) error {
		auditManifest := plugin.DeclareAuditManifest(p.Manifest)
		if auditManifest.OnConnectionEvent == nil {
			return nil
		}
		connInfo := cc.ctx.GetSessionVars().ConnectionInfo
		if err := auditManifest.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo); err != nil {
			return err
		}
		return nil
	}
	if err := plugin.ForeachPlugin(plugin.Audit, notifyAudit); err != nil {
		return err
	}
	return cc.writeOK(ctx)
}
// handleRefresh handles a refresh command. Every sub-command is safe to treat as
// a no-op except 0x01, which is executed as "FLUSH PRIVILEGES". An OK packet is
// written unless that query fails.
func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error {
	const flushPrivilegesSubCommand = 0x01

	if subCommand != flushPrivilegesSubCommand {
		return cc.writeOK(ctx)
	}
	if err := cc.handleQuery(ctx, "FLUSH PRIVILEGES"); err != nil {
		return err
	}
	return cc.writeOK(ctx)
}
// Compile-time check that getLastStmtInConn implements fmt.Stringer.
var _ fmt.Stringer = getLastStmtInConn{}
// getLastStmtInConn embeds a *clientConn so that a description of the last packet
// received on the connection is only built when String or PProfLabel is called.
type getLastStmtInConn struct {
*clientConn
}
// String describes the last packet received on the connection in a form suitable
// for logs. It returns "" when no packet has been recorded.
//
// The first byte of the packet is the command and the rest is its payload:
//   - ComInitDB and ComFieldList print the payload after a fixed prefix;
//   - ComQuery and ComStmtPrepare print the SQL text, normalized when
//     EnableRedactLog is set;
//   - ComStmtExecute and ComStmtFetch print the prepared statement identified by
//     the leading 4-byte little-endian statement ID;
//   - ComStmtClose and ComStmtReset print the command name and the statement ID;
//   - any other command prints its name from mysql.Command2Str, or the raw payload
//     when the command is unknown.
//
// A statement packet whose payload is shorter than the 4-byte statement ID is
// reported by its command name only instead of panicking on the short slice.
func (cc getLastStmtInConn) String() string {
	if len(cc.lastPacket) == 0 {
		return ""
	}
	cmd, data := cc.lastPacket[0], cc.lastPacket[1:]
	switch cmd {
	case mysql.ComInitDB:
		return "Use " + string(data)
	case mysql.ComFieldList:
		return "ListFields " + string(data)
	case mysql.ComQuery, mysql.ComStmtPrepare:
		sql := string(hack.String(data))
		if cc.ctx.GetSessionVars().EnableRedactLog {
			sql = parser.Normalize(sql)
		}
		return tidbutil.QueryStrForLog(sql)
	case mysql.ComStmtExecute, mysql.ComStmtFetch:
		if len(data) < 4 {
			// Truncated packet: the statement ID cannot be decoded.
			return mysql.Command2Str[cmd]
		}
		stmtID := binary.LittleEndian.Uint32(data[0:4])
		return tidbutil.QueryStrForLog(cc.preparedStmt2String(stmtID))
	case mysql.ComStmtClose, mysql.ComStmtReset:
		if len(data) < 4 {
			// Truncated packet: the statement ID cannot be decoded.
			return mysql.Command2Str[cmd]
		}
		stmtID := binary.LittleEndian.Uint32(data[0:4])
		// Format as unsigned so large IDs do not turn negative where int is 32 bits.
		return mysql.Command2Str[cmd] + " " + strconv.FormatUint(uint64(stmtID), 10)
	default:
		if cmdStr, ok := mysql.Command2Str[cmd]; ok {
			return cmdStr
		}
		return string(hack.String(data))
	}
}
// PProfLabel returns the SQL label used to tag pprof for the last packet received
// on the connection. It returns "" when no packet has been recorded, when the
// command has no label, or when a ComStmtExecute/ComStmtFetch payload is too short
// to contain the 4-byte statement ID (instead of panicking on the short slice).
func (cc getLastStmtInConn) PProfLabel() string {
	if len(cc.lastPacket) == 0 {
		return ""
	}
	cmd, data := cc.lastPacket[0], cc.lastPacket[1:]
	switch cmd {
	case mysql.ComInitDB:
		return "UseDB"
	case mysql.ComFieldList:
		return "ListFields"
	case mysql.ComStmtClose:
		return "CloseStmt"
	case mysql.ComStmtReset:
		return "ResetStmt"
	case mysql.ComQuery, mysql.ComStmtPrepare:
		return parser.Normalize(tidbutil.QueryStrForLog(string(hack.String(data))))
	case mysql.ComStmtExecute, mysql.ComStmtFetch:
		if len(data) < 4 {
			// Truncated packet: the statement ID cannot be decoded.
			return ""
		}
		stmtID := binary.LittleEndian.Uint32(data[0:4])
		return tidbutil.QueryStrForLog(cc.preparedStmt2StringNoArgs(stmtID))
	default:
		return ""
	}
}
// Compile-time check that *clientConn implements conn.AuthConn.
var _ conn.AuthConn = (*clientConn)(nil)

// WriteAuthMoreData implements `conn.AuthConn` interface.
// See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_more_data.html
// The `AuthMoreData` packet is just an arbitrary binary slice with a byte 0x1 as
// prefix; the four zero bytes in front of it are the leading bytes of the packet buffer.
func (cc *clientConn) WriteAuthMoreData(data []byte) error {
	pkt := make([]byte, 0, 5+len(data))
	pkt = append(pkt, 0, 0, 0, 0, 1)
	pkt = append(pkt, data...)
	return cc.writePacket(pkt)
}
// ReadPacket implements `conn.AuthConn` interface.
// It is an exported wrapper that returns the result of cc.readPacket unchanged.
func (cc *clientConn) ReadPacket() ([]byte, error) {
return cc.readPacket()
}
// Flush implements `conn.AuthConn` interface.
// It is an exported wrapper that returns the result of cc.flush(ctx) unchanged.
func (cc *clientConn) Flush(ctx context.Context) error {
return cc.flush(ctx)
}