diff --git a/domain/BUILD.bazel b/domain/BUILD.bazel index 46f59b5f0ffbd..cb2e31915513d 100644 --- a/domain/BUILD.bazel +++ b/domain/BUILD.bazel @@ -65,6 +65,7 @@ go_library( "//util/execdetails", "//util/expensivequery", "//util/gctuner", + "//util/globalconn", "//util/intest", "//util/logutil", "//util/memory", diff --git a/domain/db_test.go b/domain/db_test.go index 02f716a27b3a8..06b82ea2f9fd5 100644 --- a/domain/db_test.go +++ b/domain/db_test.go @@ -83,7 +83,6 @@ func TestNormalSessionPool(t *testing.T) { svr, err := server.NewServer(conf, nil) require.NoError(t, err) svr.SetDomain(domain) - svr.InitGlobalConnID(domain.ServerID) info.SetSessionManager(svr) pool := domain.SysSessionPool() @@ -117,7 +116,6 @@ func TestAbnormalSessionPool(t *testing.T) { svr, err := server.NewServer(conf, nil) require.NoError(t, err) svr.SetDomain(domain) - svr.InitGlobalConnID(domain.ServerID) info.SetSessionManager(svr) pool := domain.SysSessionPool() diff --git a/domain/domain.go b/domain/domain.go index d4c3b436a2cac..7def7349e1127 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -75,6 +75,7 @@ import ( "github.com/pingcap/tidb/util/etcd" "github.com/pingcap/tidb/util/expensivequery" "github.com/pingcap/tidb/util/gctuner" + "github.com/pingcap/tidb/util/globalconn" "github.com/pingcap/tidb/util/intest" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" @@ -162,8 +163,10 @@ type Domain struct { serverID uint64 serverIDSession *concurrency.Session isLostConnectionToPD atomicutil.Int32 // !0: true, 0: false. - onClose func() - sysExecutorFactory func(*Domain) (pools.Resource, error) + connIDAllocator globalconn.Allocator + + onClose func() + sysExecutorFactory func(*Domain) (pools.Resource, error) sysProcesses SysProcesses @@ -1141,6 +1144,8 @@ func (do *Domain) Init( } if config.GetGlobalConfig().EnableGlobalKill { + do.connIDAllocator = globalconn.NewGlobalAllocator(do.ServerID) + if do.etcdClient != nil { err := do.acquireServerID(ctx) if err != nil { @@ -1156,6 +1161,8 @@ func (do *Domain) Init( // set serverID for standalone deployment to enable 'KILL'. atomic.StoreUint64(&do.serverID, serverIDForStandalone) } + } else { + do.connIDAllocator = globalconn.NewSimpleAllocator() } // step 1: prepare the info/schema syncer which domain reload needed. @@ -1509,6 +1516,7 @@ func (p *sessionPool) Put(resource pools.Resource) { resource.Close() } } + func (p *sessionPool) Close() { p.mu.Lock() if p.mu.closed { @@ -2066,7 +2074,7 @@ func (do *Domain) StatsHandle() *handle.Handle { // CreateStatsHandle is used only for test. func (do *Domain) CreateStatsHandle(ctx, initStatsCtx sessionctx.Context) error { - h, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) + h, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.GetAutoAnalyzeProcID) if err != nil { return err } @@ -2142,7 +2150,7 @@ func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context, initStatsCtx // It should be called only once in BootstrapSession. 
func (do *Domain) UpdateTableStatsLoop(ctx, initStatsCtx sessionctx.Context) error { ctx.GetSessionVars().InRestrictedSQL = true - statsHandle, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) + statsHandle, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.GetAutoAnalyzeProcID) if err != nil { return err } @@ -2532,12 +2540,31 @@ func (do *Domain) IsLostConnectionToPD() bool { return do.isLostConnectionToPD.Load() != 0 } +// NextConnID return next connection ID. +func (do *Domain) NextConnID() uint64 { + return do.connIDAllocator.NextID() +} + +// ReleaseConnID releases connection ID. +func (do *Domain) ReleaseConnID(connID uint64) { + do.connIDAllocator.Release(connID) +} + +// GetAutoAnalyzeProcID returns processID for auto analyze +// TODO: support IDs for concurrent auto-analyze +func (do *Domain) GetAutoAnalyzeProcID() uint64 { + return do.connIDAllocator.GetReservedConnID(reservedConnAnalyze) +} + const ( serverIDEtcdPath = "/tidb/server_id" refreshServerIDRetryCnt = 3 acquireServerIDRetryInterval = 300 * time.Millisecond acquireServerIDTimeout = 10 * time.Second retrieveServerIDSessionTimeout = 10 * time.Second + + // reservedConnXXX must be within [0, globalconn.ReservedCount) + reservedConnAnalyze = 0 ) var ( @@ -2631,8 +2658,8 @@ func (do *Domain) acquireServerID(ctx context.Context) error { } for { - // get a random serverID: [1, MaxServerID] - randServerID := rand.Int63n(int64(util.MaxServerID)) + 1 // #nosec G404 + // get a random serverID: [1, MaxServerID64] + randServerID := rand.Int63n(int64(globalconn.MaxServerID64)) + 1 // #nosec G404 key := fmt.Sprintf("%s/%v", serverIDEtcdPath, randServerID) cmp := clientv3.Compare(clientv3.CreateRevision(key), "=", 0) value := "0" diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 7506681c5771e..38d7e33581ca6 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -185,6 +185,7 @@ go_library( "//util/execdetails", "//util/format", "//util/gcutil", + "//util/globalconn", "//util/hack", "//util/hint", "//util/intest", @@ -428,6 +429,7 @@ go_test( "//util/disk", "//util/execdetails", "//util/gcutil", + "//util/globalconn", "//util/hack", "//util/logutil", "//util/mathutil", diff --git a/executor/analyze.go b/executor/analyze.go index 74d665aedf685..3aa5b08f47c46 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -171,7 +171,7 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { } failpoint.Inject("mockKillPendingAnalyzeJob", func() { dom := domain.GetDomain(e.ctx) - dom.SysProcTracker().KillSysProcess(util.GetAutoAnalyzeProcID(dom.ServerID)) + dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID()) }) for _, task := range tasks { taskCh <- task @@ -194,7 +194,7 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { } failpoint.Inject("mockKillFinishedAnalyzeJob", func() { dom := domain.GetDomain(e.ctx) - dom.SysProcTracker().KillSysProcess(util.GetAutoAnalyzeProcID(dom.ServerID)) + dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID()) }) // If we enabled dynamic prune mode, then we need to generate global stats here for partition tables. 
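The domain changes above move connection-ID allocation behind a single `globalconn.Allocator` owned by the `Domain`: `Domain.Init` builds a `GlobalAllocator` (fed by the `do.ServerID` getter) when `EnableGlobalKill` is on and a `SimpleAllocator` otherwise, and `NextConnID`, `ReleaseConnID` and `GetAutoAnalyzeProcID` become thin wrappers over it. A minimal, self-contained sketch of that wiring pattern follows; the types are simplified stand-ins for illustration, not the TiDB implementations.

```go
package main

import "fmt"

// Allocator mirrors the interface the Domain now owns: hand out connection
// IDs, take them back, and expose a small reserved range for internal jobs.
type Allocator interface {
	NextID() uint64
	Release(id uint64)
	GetReservedConnID(reservedNo uint64) uint64
}

// simpleAllocator is a stand-in for globalconn.SimpleAllocator (global kill off).
type simpleAllocator struct{ next uint64 }

func (a *simpleAllocator) NextID() uint64 { a.next++; return a.next }

func (a *simpleAllocator) Release(uint64) {}

func (a *simpleAllocator) GetReservedConnID(reservedNo uint64) uint64 {
	return ^uint64(0) - reservedNo // reserved IDs come from the top of the ID space
}

// fakeDomain is a stand-in for domain.Domain: the allocator is chosen once at
// Init time, and the server / auto-analyze code only talks to these wrappers.
type fakeDomain struct{ connIDAllocator Allocator }

func newFakeDomain(enableGlobalKill bool, globalAlloc Allocator) *fakeDomain {
	d := &fakeDomain{}
	if enableGlobalKill && globalAlloc != nil {
		d.connIDAllocator = globalAlloc // globalconn.NewGlobalAllocator(do.ServerID) in the diff
	} else {
		d.connIDAllocator = &simpleAllocator{} // globalconn.NewSimpleAllocator() in the diff
	}
	return d
}

func (d *fakeDomain) NextConnID() uint64           { return d.connIDAllocator.NextID() }
func (d *fakeDomain) ReleaseConnID(id uint64)      { d.connIDAllocator.Release(id) }
func (d *fakeDomain) GetAutoAnalyzeProcID() uint64 { return d.connIDAllocator.GetReservedConnID(0) }

func main() {
	do := newFakeDomain(false, nil)
	id := do.NextConnID()      // what server.newClientConn now calls
	defer do.ReleaseConnID(id) // what server.closeConn now calls
	fmt.Println(id, do.GetAutoAnalyzeProcID())
}
```

This is also why `handle.NewHandle` now receives the `GetAutoAnalyzeProcID` getter instead of a `ServerID` getter: the reserved ID is derived inside the allocator rather than recomputed by the stats handle.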
diff --git a/executor/analyze_col.go b/executor/analyze_col.go index 4679cc4a6f4bc..e5bade4d3dca4 100644 --- a/executor/analyze_col.go +++ b/executor/analyze_col.go @@ -176,7 +176,7 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo for { failpoint.Inject("mockKillRunningV1AnalyzeJob", func() { dom := domain.GetDomain(e.ctx) - dom.SysProcTracker().KillSysProcess(util.GetAutoAnalyzeProcID(dom.ServerID)) + dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID()) }) if atomic.LoadUint32(&e.ctx.GetSessionVars().Killed) == 1 { return nil, nil, nil, nil, nil, errors.Trace(exeerrors.ErrQueryInterrupted) diff --git a/executor/analyze_col_v2.go b/executor/analyze_col_v2.go index f976e844c79f6..f7c1f10e6b5b4 100644 --- a/executor/analyze_col_v2.go +++ b/executor/analyze_col_v2.go @@ -817,7 +817,7 @@ func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, me for { failpoint.Inject("mockKillRunningV2AnalyzeJob", func() { dom := domain.GetDomain(ctx) - dom.SysProcTracker().KillSysProcess(util.GetAutoAnalyzeProcID(dom.ServerID)) + dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID()) }) if atomic.LoadUint32(&ctx.GetSessionVars().Killed) == 1 { return errors.Trace(exeerrors.ErrQueryInterrupted) diff --git a/executor/analyze_idx.go b/executor/analyze_idx.go index 8ba61499dac12..a6acd640479e8 100644 --- a/executor/analyze_idx.go +++ b/executor/analyze_idx.go @@ -31,7 +31,6 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/dbterror/exeerrors" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/ranger" @@ -195,7 +194,7 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee for { failpoint.Inject("mockKillRunningAnalyzeIndexJob", func() { dom := domain.GetDomain(e.ctx) - dom.SysProcTracker().KillSysProcess(util.GetAutoAnalyzeProcID(dom.ServerID)) + dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID()) }) if atomic.LoadUint32(&e.ctx.GetSessionVars().Killed) == 1 { return nil, nil, nil, nil, errors.Trace(exeerrors.ErrQueryInterrupted) diff --git a/executor/simple.go b/executor/simple.go index 6cb42e4a195bd..d1be1a506d3ba 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -52,6 +52,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/dbterror/exeerrors" + "github.com/pingcap/tidb/util/globalconn" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/mathutil" @@ -2574,7 +2575,7 @@ func (e *SimpleExec) executeKillStmt(ctx context.Context, s *ast.KillStmt) error return nil } - connID, isTruncated, err := util.ParseGlobalConnID(s.ConnectionID) + gcid, isTruncated, err := globalconn.ParseConnID(s.ConnectionID) if err != nil { err1 := errors.New("Parse ConnectionID failed: " + err.Error()) e.ctx.GetSessionVars().StmtCtx.AppendWarning(err1) @@ -2590,8 +2591,8 @@ func (e *SimpleExec) executeKillStmt(ctx context.Context, s *ast.KillStmt) error return nil } - if connID.ServerID != sm.ServerID() { - if err := killRemoteConn(ctx, e.ctx, &connID, s.Query); err != nil { + if gcid.ServerID != sm.ServerID() { + if err := killRemoteConn(ctx, e.ctx, &gcid, s.Query); err != nil { err1 := errors.New("KILL remote connection failed: " + err.Error()) e.ctx.GetSessionVars().StmtCtx.AppendWarning(err1) } @@ -2602,14 +2603,14 @@ func (e *SimpleExec) 
executeKillStmt(ctx context.Context, s *ast.KillStmt) error return nil } -func killRemoteConn(ctx context.Context, sctx sessionctx.Context, connID *util.GlobalConnID, query bool) error { - if connID.ServerID == 0 { +func killRemoteConn(ctx context.Context, sctx sessionctx.Context, gcid *globalconn.GCID, query bool) error { + if gcid.ServerID == 0 { return errors.New("Unexpected ZERO ServerID. Please file a bug to the TiDB Team") } killExec := &tipb.Executor{ Tp: tipb.ExecType_TypeKill, - Kill: &tipb.Kill{ConnID: connID.ID(), Query: query}, + Kill: &tipb.Kill{ConnID: gcid.ToConnID(), Query: query}, } dagReq := &tipb.DAGRequest{} @@ -2628,7 +2629,7 @@ func killRemoteConn(ctx context.Context, sctx sessionctx.Context, connID *util.G SetFromSessionVars(sctx.GetSessionVars()). SetFromInfoSchema(sctx.GetInfoSchema()). SetStoreType(kv.TiDB). - SetTiDBServerID(connID.ServerID). + SetTiDBServerID(gcid.ServerID). Build() if err != nil { return err @@ -2639,8 +2640,8 @@ func killRemoteConn(ctx context.Context, sctx sessionctx.Context, connID *util.G return err } - logutil.BgLogger().Info("Killed remote connection", zap.Uint64("serverID", connID.ServerID), - zap.Uint64("conn", connID.ID()), zap.Bool("query", query)) + logutil.BgLogger().Info("Killed remote connection", zap.Uint64("serverID", gcid.ServerID), + zap.Uint64("conn", gcid.ToConnID()), zap.Bool("query", query)) return err } diff --git a/executor/simple_test.go b/executor/simple_test.go index b0bda248bb5ba..d53269d6336c7 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -26,7 +26,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/server" "github.com/pingcap/tidb/testkit" - "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/globalconn" "github.com/stretchr/testify/require" ) @@ -71,11 +71,12 @@ func TestKillStmt(t *testing.T) { // excceed int64 tk.MustExec("kill 9223372036854775808") // 9223372036854775808 == 2^63 result = tk.MustQuery("show warnings") - result.Check(testkit.Rows("Warning 1105 Parse ConnectionID failed: Unexpected connectionID excceeds int64")) + result.Check(testkit.Rows("Warning 1105 Parse ConnectionID failed: unexpected connectionID exceeds int64")) // local kill - killConnID := util.NewGlobalConnID(connID, true) - tk.MustExec("kill " + strconv.FormatUint(killConnID.ID(), 10)) + connIDAllocator := globalconn.NewGlobalAllocator(dom.ServerID) + killConnID := connIDAllocator.NextID() + tk.MustExec("kill " + strconv.FormatUint(killConnID, 10)) result = tk.MustQuery("show warnings") result.Check(testkit.Rows()) diff --git a/go.mod b/go.mod index 407f9d672a727..3a34e9617aa07 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( github.com/lestrrat-go/jwx/v2 v2.0.6 github.com/mgechev/revive v1.3.2 github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 + github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef github.com/nishanths/predeclared v0.2.2 github.com/opentracing/basictracer-go v1.0.0 github.com/opentracing/opentracing-go v1.2.0 @@ -225,7 +226,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect github.com/ncw/directio v1.0.5 // indirect - github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/petermattis/goid v0.0.0-20211229010228-4d14c490ee36 // indirect diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 
293b27be63d63..721cc5dbe18a1 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3529,7 +3529,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, b.visitInfo = appendDynamicVisitInfo(b.visitInfo, "CONNECTION_ADMIN", false, err) b.visitInfo = appendVisitInfoIsRestrictedUser(b.visitInfo, b.ctx, &auth.UserIdentity{Username: pi.User, Hostname: pi.Host}, "RESTRICTED_CONNECTION_ADMIN") } - } else if raw.ConnectionID == util2.GetAutoAnalyzeProcID(domain.GetDomain(b.ctx).ServerID) { + } else if raw.ConnectionID == domain.GetDomain(b.ctx).GetAutoAnalyzeProcID() { // Only the users with SUPER or CONNECTION_ADMIN privilege can kill auto analyze. err := ErrSpecificAccessDenied.GenWithStackByArgs("SUPER or CONNECTION_ADMIN") b.visitInfo = appendDynamicVisitInfo(b.visitInfo, "CONNECTION_ADMIN", false, err) diff --git a/server/conn.go b/server/conn.go index 12e2846035c47..848eeb4a80456 100644 --- a/server/conn.go +++ b/server/conn.go @@ -110,7 +110,7 @@ const ( func newClientConn(s *Server) *clientConn { return &clientConn{ server: s, - connectionID: s.globalConnID.NextID(), + connectionID: s.dom.NextConnID(), collation: mysql.DefaultCollationID, alloc: arena.NewAllocator(32 * 1024), chunkAlloc: chunk.NewAllocator(), @@ -330,6 +330,7 @@ func (cc *clientConn) Close() error { func closeConn(cc *clientConn, connections int) error { metrics.ConnGauge.Set(float64(connections)) + cc.server.dom.ReleaseConnID(cc.connectionID) if cc.bufReadConn != nil { err := cc.bufReadConn.Close() if err != nil { diff --git a/server/extract_test.go b/server/extract_test.go index a595fbe7168bc..d84079df11660 100644 --- a/server/extract_test.go +++ b/server/extract_test.go @@ -50,6 +50,10 @@ func TestExtractHandler(t *testing.T) { require.NoError(t, err) defer server.Close() + dom, err := session.GetDomain(store) + require.NoError(t, err) + server.SetDomain(dom) + client.port = getPortFromTCPAddr(server.listener.Addr()) client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { @@ -62,8 +66,6 @@ func TestExtractHandler(t *testing.T) { prepareData4ExtractPlanTask(t, client) time.Sleep(time.Second) endTime := time.Now() - dom, err := session.GetDomain(store) - require.NoError(t, err) eh := &ExtractTaskServeHandler{extractHandler: dom.GetExtractHandle()} router := mux.NewRouter() router.Handle("/extract_task/dump", eh) diff --git a/server/http_handler_test.go b/server/http_handler_test.go index c28967abf06ee..ca861c18d0172 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -459,6 +459,7 @@ func (ts *basicHTTPHandlerTestSuite) startServer(t *testing.T) { ts.port = getPortFromTCPAddr(server.listener.Addr()) ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.server = server + ts.server.SetDomain(ts.domain) go func() { err := server.Run() require.NoError(t, err) diff --git a/server/mock_conn.go b/server/mock_conn.go index 692a66c888a7d..8f54281512449 100644 --- a/server/mock_conn.go +++ b/server/mock_conn.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/auth" tmysql "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/intest" @@ -90,6 +91,9 @@ func CreateMockServer(t *testing.T, store kv.Storage) *Server { cfg.Security.AutoTLS = false server, err := NewServer(cfg, tidbdrv) require.NoError(t, err) + dom, err := session.GetDomain(store) + 
require.NoError(t, err) + server.SetDomain(dom) return server } diff --git a/server/optimize_trace_test.go b/server/optimize_trace_test.go index 7cffc8cd50c2e..6da9a021401e9 100644 --- a/server/optimize_trace_test.go +++ b/server/optimize_trace_test.go @@ -50,6 +50,10 @@ func TestDumpOptimizeTraceAPI(t *testing.T) { require.NoError(t, err) defer server.Close() + dom, err := session.GetDomain(store) + require.NoError(t, err) + server.SetDomain(dom) + client.port = getPortFromTCPAddr(server.listener.Addr()) client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { @@ -58,8 +62,6 @@ func TestDumpOptimizeTraceAPI(t *testing.T) { }() client.waitUntilServerOnline() - dom, err := session.GetDomain(store) - require.NoError(t, err) statsHandler := &StatsHandler{dom} otHandler := &OptimizeTraceHandler{} diff --git a/server/plan_replayer_test.go b/server/plan_replayer_test.go index 2f2308efdd1a9..e863ecdaf71a7 100644 --- a/server/plan_replayer_test.go +++ b/server/plan_replayer_test.go @@ -44,6 +44,10 @@ func TestDumpPlanReplayerAPI(t *testing.T) { require.NoError(t, err) defer server.Close() + dom, err := session.GetDomain(store) + require.NoError(t, err) + server.SetDomain(dom) + client.port = getPortFromTCPAddr(server.listener.Addr()) client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { @@ -52,8 +56,6 @@ func TestDumpPlanReplayerAPI(t *testing.T) { }() client.waitUntilServerOnline() - dom, err := session.GetDomain(store) - require.NoError(t, err) statsHandler := &StatsHandler{dom} planReplayerHandler := &PlanReplayerHandler{} diff --git a/server/server.go b/server/server.go index a1047e1593c8c..937fb0a4da7c0 100644 --- a/server/server.go +++ b/server/server.go @@ -135,9 +135,8 @@ type Server struct { rwlock sync.RWMutex clients map[uint64]*clientConn - capability uint32 - dom *domain.Domain - globalConnID util.GlobalConnID + capability uint32 + dom *domain.Domain statusAddr string statusListener net.Listener @@ -181,11 +180,6 @@ func (s *Server) SetDomain(dom *domain.Domain) { s.dom = dom } -// InitGlobalConnID initialize global connection id. -func (s *Server) InitGlobalConnID(serverIDGetter func() uint64) { - s.globalConnID = util.NewGlobalConnIDWithGetter(serverIDGetter, true) -} - // newConn creates a new *clientConn from a net.Conn. // It allocates a connection ID and random salt data for authentication. func (s *Server) newConn(conn net.Conn) *clientConn { @@ -210,7 +204,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { driver: driver, concurrentLimiter: NewTokenLimiter(cfg.TokenLimit), clients: make(map[uint64]*clientConn), - globalConnID: util.NewGlobalConnID(0, true), internalSessions: make(map[interface{}]struct{}, 100), health: uatomic.NewBool(true), inShutdownMode: uatomic.NewBool(false), @@ -940,6 +933,11 @@ func (s *Server) ServerID() uint64 { return s.dom.ServerID() } +// GetAutoAnalyzeProcID implements SessionManager interface. +func (s *Server) GetAutoAnalyzeProcID() uint64 { + return s.dom.GetAutoAnalyzeProcID() +} + // StoreInternalSession implements SessionManager interface. 
// @param addr The address of a session.session struct variable func (s *Server) StoreInternalSession(se interface{}) { @@ -961,7 +959,7 @@ func (s *Server) GetInternalSessionStartTSList() []uint64 { s.sessionMapMutex.Lock() defer s.sessionMapMutex.Unlock() tsList := make([]uint64, 0, len(s.internalSessions)) - analyzeProcID := util.GetAutoAnalyzeProcID(s.ServerID) + analyzeProcID := s.GetAutoAnalyzeProcID() for se := range s.internalSessions { if ts, processInfoID := session.GetStartTSFromSession(se); ts != 0 { if processInfoID == analyzeProcID { diff --git a/server/statistics_handler_test.go b/server/statistics_handler_test.go index e0ecc7ba853f0..033147b7485a8 100644 --- a/server/statistics_handler_test.go +++ b/server/statistics_handler_test.go @@ -47,6 +47,10 @@ func TestDumpStatsAPI(t *testing.T) { require.NoError(t, err) defer server.Close() + dom, err := session.GetDomain(store) + require.NoError(t, err) + server.SetDomain(dom) + client.port = getPortFromTCPAddr(server.listener.Addr()) client.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) go func() { @@ -55,8 +59,6 @@ func TestDumpStatsAPI(t *testing.T) { }() client.waitUntilServerOnline() - dom, err := session.GetDomain(store) - require.NoError(t, err) statsHandler := &StatsHandler{dom} prepareData(t, client, statsHandler) diff --git a/server/tidb_serial_test.go b/server/tidb_serial_test.go index fdb1987f16104..340eac4f0b272 100644 --- a/server/tidb_serial_test.go +++ b/server/tidb_serial_test.go @@ -160,6 +160,7 @@ func TestTLSAuto(t *testing.T) { require.NoError(t, err) server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -205,6 +206,7 @@ func TestTLSBasic(t *testing.T) { } server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -268,6 +270,7 @@ func TestTLSVerify(t *testing.T) { } server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) defer server.Close() cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { @@ -372,6 +375,7 @@ func TestErrorNoRollback(t *testing.T) { } server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -516,6 +520,7 @@ func TestReloadTLS(t *testing.T) { } server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() diff --git a/server/tidb_test.go b/server/tidb_test.go index b51aa67ccd25c..3954f8447ca9c 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -109,7 +109,6 @@ func createTidbTestSuiteWithCfg(t *testing.T, cfg *config.Config) *tidbTestSuite ts.statusPort = getPortFromTCPAddr(server.statusListener.Addr()) ts.server = server ts.server.SetDomain(ts.domain) - ts.server.InitGlobalConnID(ts.domain.ServerID) ts.domain.InfoSyncer().SetSessionManager(ts.server) go func() { err := ts.server.Run() @@ -398,6 +397,7 @@ func TestSocketForwarding(t *testing.T) { server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -430,6 +430,7 @@ func TestSocket(t *testing.T) { server, err := NewServer(cfg, ts.tidbdrv) 
require.NoError(t, err) + server.SetDomain(ts.domain) go func() { err := server.Run() require.NoError(t, err) @@ -464,6 +465,7 @@ func TestSocketAndIp(t *testing.T) { server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -628,6 +630,7 @@ func TestOnlySocket(t *testing.T) { server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) go func() { err := server.Run() require.NoError(t, err) @@ -2504,6 +2507,7 @@ func TestLocalhostClientMapping(t *testing.T) { server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) cli.port = getPortFromTCPAddr(server.listener.Addr()) go func() { err := server.Run() @@ -3121,6 +3125,7 @@ func TestProxyProtocolWithIpFallbackable(t *testing.T) { // Prepare Server server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) go func() { err := server.Run() require.NoError(t, err) @@ -3185,6 +3190,7 @@ func TestProxyProtocolWithIpNoFallbackable(t *testing.T) { // Prepare Server server, err := NewServer(cfg, ts.tidbdrv) require.NoError(t, err) + server.SetDomain(ts.domain) go func() { err := server.Run() require.NoError(t, err) diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index b27d23e3e1bb4..b0a3251103447 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -72,8 +72,8 @@ type Handle struct { // sysProcTracker is used to track sys process like analyze sysProcTracker sessionctx.SysProcTracker - // serverIDGetter is used to get server ID for generating auto analyze ID. - serverIDGetter func() uint64 + // autoAnalyzeProcIDGetter is used to generate auto analyze ID. + autoAnalyzeProcIDGetter func() uint64 InitStatsDone chan struct{} @@ -484,16 +484,16 @@ type sessionPool interface { } // NewHandle creates a Handle for update stats. 
-func NewHandle(ctx, initStatsCtx sessionctx.Context, lease time.Duration, pool sessionPool, tracker sessionctx.SysProcTracker, serverIDGetter func() uint64) (*Handle, error) { +func NewHandle(ctx, initStatsCtx sessionctx.Context, lease time.Duration, pool sessionPool, tracker sessionctx.SysProcTracker, autoAnalyzeProcIDGetter func() uint64) (*Handle, error) { cfg := config.GetGlobalConfig() handle := &Handle{ - ddlEventCh: make(chan *ddlUtil.Event, 1000), - listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, - idxUsageListHead: &SessionIndexUsageCollector{mapper: make(indexUsageMap)}, - pool: pool, - sysProcTracker: tracker, - serverIDGetter: serverIDGetter, - InitStatsDone: make(chan struct{}), + ddlEventCh: make(chan *ddlUtil.Event, 1000), + listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, + idxUsageListHead: &SessionIndexUsageCollector{mapper: make(indexUsageMap)}, + pool: pool, + sysProcTracker: tracker, + autoAnalyzeProcIDGetter: autoAnalyzeProcIDGetter, + InitStatsDone: make(chan struct{}), } handle.initStatsCtx = initStatsCtx handle.lease.Store(lease) diff --git a/statistics/handle/handletest/handle_test.go b/statistics/handle/handletest/handle_test.go index 627efc7a9167b..060b6e217134e 100644 --- a/statistics/handle/handletest/handle_test.go +++ b/statistics/handle/handletest/handle_test.go @@ -168,7 +168,7 @@ func TestVersion(t *testing.T) { tbl1, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) require.NoError(t, err) tableInfo1 := tbl1.Meta() - h, err := handle.NewHandle(testKit.Session(), testKit2.Session(), time.Millisecond, do.SysSessionPool(), do.SysProcTracker(), do.ServerID) + h, err := handle.NewHandle(testKit.Session(), testKit2.Session(), time.Millisecond, do.SysSessionPool(), do.SysProcTracker(), do.GetAutoAnalyzeProcID) require.NoError(t, err) unit := oracle.ComposeTS(1, 0) testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", 2*unit, tableInfo1.ID) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 9b7b713f2a15d..95ca2a959bac3 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -1296,7 +1296,7 @@ var execOptionForAnalyze = map[int]sqlexec.OptionFuncAlias{ func (h *Handle) execAutoAnalyze(statsVer int, analyzeSnapshot bool, sql string, params ...interface{}) { startTime := time.Now() - autoAnalyzeProcID := util.GetAutoAnalyzeProcID(h.serverIDGetter) + autoAnalyzeProcID := h.autoAnalyzeProcIDGetter() _, _, err := h.execRestrictedSQLWithStatsVer(context.Background(), statsVer, autoAnalyzeProcID, analyzeSnapshot, sql, params...) dur := time.Since(startTime) metrics.AutoAnalyzeHistogram.Observe(dur.Seconds()) diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index a5955b081601f..c4d7297ab9857 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -120,6 +120,11 @@ func (msm *MockSessionManager) ServerID() uint64 { return msm.SerID } +// GetAutoAnalyzeProcID implement SessionManager interface. +func (msm *MockSessionManager) GetAutoAnalyzeProcID() uint64 { + return uint64(1) +} + // StoreInternalSession is to store internal session. 
func (msm *MockSessionManager) StoreInternalSession(s interface{}) { msm.mu.Lock() diff --git a/tests/globalkilltest/.gitignore b/tests/globalkilltest/.gitignore index 90f1f9540935d..e69de29bb2d1d 100644 --- a/tests/globalkilltest/.gitignore +++ b/tests/globalkilltest/.gitignore @@ -1,2 +0,0 @@ -pd* -tikv* diff --git a/tests/globalkilltest/Dockerfile b/tests/globalkilltest/Dockerfile new file mode 100644 index 0000000000000..e433b2d2e6efe --- /dev/null +++ b/tests/globalkilltest/Dockerfile @@ -0,0 +1,37 @@ +# Copyright 2023 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM rockylinux:9 + +RUN dnf update -y && dnf groupinstall 'Development Tools' -y +RUN dnf install procps-ng mysql -y + +ENV GOLANG_VERSION 1.20.3 +ENV ARCH amd64 +ENV GOLANG_DOWNLOAD_URL https://dl.google.com/go/go$GOLANG_VERSION.linux-$ARCH.tar.gz +ENV GOPATH /go +ENV GOROOT /usr/local/go +ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH +RUN curl -fsSL "$GOLANG_DOWNLOAD_URL" -o golang.tar.gz \ + && tar -C /usr/local -xzf golang.tar.gz \ + && rm golang.tar.gz + +RUN mkdir -p /go/src/github.com/pingcap/tidb +WORKDIR /go/src/github.com/pingcap/tidb + +COPY go.mod . +COPY go.sum . +COPY parser/go.mod parser/go.mod +COPY parser/go.sum parser/go.sum +RUN GO111MODULE=on go mod download diff --git a/tests/globalkilltest/README.md b/tests/globalkilltest/README.md index 49e949f56e45a..afeb7e5470c6e 100644 --- a/tests/globalkilltest/README.md +++ b/tests/globalkilltest/README.md @@ -24,7 +24,7 @@ Usage: ./run-tests.sh [options] --tidb_status_port : First TiDB server status listening port. port ~ port+2 will be used. Defaults to "8000". - --pd : PD client path, ip:port list seperated by comma. + --pd : PD client path, ip:port list separated by comma. Defaults to "127.0.0.1:2379". --pd_proxy_port : PD proxy port. PD proxy is used to simulate lost connection between TiDB and PD. @@ -47,8 +47,22 @@ Usage: ./run-tests.sh [options] ## Prerequisite 1. Build TiDB binary for test. See [Makefile](https://github.com/pingcap/tidb/blob/master/tests/globalkilltest/Makefile) for detail. -2. Establish a cluster with PD & TiKV, and provide PD client path by `--pd=ip:port[,ip:port]`. +2. Prepare `pd-server` and `tikv-server` to setup a cluster for tests. You can download the binaries by `TiUP` +```bash +cd tests/globalkilltest +mkdir -p bin +tiup install pd:nightly tikv:nightly +cp ~/.tiup/components/pd/$(ls ~/.tiup/components/pd | tail -1)/pd-server bin/ +cp ~/.tiup/components/tikv/$(ls ~/.tiup/components/tikv | tail -1)/tikv-server bin/ +``` + +Alternatively, if you have Docker environment, you can run `up.sh`, which will prepare binaries & run `make` for you: + +```sh +cd tests/globalkilltest +./up.sh +``` ## Test Scenarios @@ -83,11 +97,9 @@ In Integration Test after commit and before merge, run these commands under TiDB ```sh cd tests/globalkilltest make -./run-tests.sh --pd= +./run-tests.sh ``` -Again, before testing, establish a cluster with PD & TiKV and provide `pd client path` by `--pd=`. 
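What `run-tests.sh` ultimately exercises is the cross-instance KILL path touched in `executor/simple.go` above: a connection ID minted on one TiDB embeds its serverID, so a `KILL` issued on another TiDB of the same cluster is routed back to the owner. A rough standalone sketch of that check is shown below; it is not the harness code in `global_kill_test.go`, and the DSNs and ports are placeholders.

```go
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	ctx := context.Background()

	// Two TiDB instances of the same cluster; the ports are placeholders.
	dbA, err := sql.Open("mysql", "root@tcp(127.0.0.1:5000)/test")
	if err != nil {
		log.Fatal(err)
	}
	dbB, err := sql.Open("mysql", "root@tcp(127.0.0.1:5001)/test")
	if err != nil {
		log.Fatal(err)
	}

	// Pin one session on instance A and read its global connection ID.
	connA, err := dbA.Conn(ctx)
	if err != nil {
		log.Fatal(err)
	}
	var connID uint64
	if err := connA.QueryRowContext(ctx, "SELECT CONNECTION_ID()").Scan(&connID); err != nil {
		log.Fatal(err)
	}

	// KILL it from instance B; the serverID bits route the request back to A.
	if _, err := dbB.ExecContext(ctx, fmt.Sprintf("KILL %d", connID)); err != nil {
		log.Fatal(err)
	}

	// The pinned session should now fail on its next statement.
	if _, err := connA.ExecContext(ctx, "SELECT 1"); err != nil {
		fmt.Println("connection killed as expected:", err)
	}
}
```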
- ### Manual Test Run a single test manually (take `TestMultipleTiDB` as example): diff --git a/tests/globalkilltest/global_kill_test.go b/tests/globalkilltest/global_kill_test.go index d7732dcc83eca..7c9836fe6f269 100644 --- a/tests/globalkilltest/global_kill_test.go +++ b/tests/globalkilltest/global_kill_test.go @@ -130,12 +130,13 @@ func (s *GlobalKillSuite) connectPD() (cli *clientv3.Client, err error) { func (s *GlobalKillSuite) startTiKV(dataDir string) (err error) { s.tikvProc = exec.Command(*tikvBinaryPath, fmt.Sprintf("--pd=%s", *pdClientPath), - fmt.Sprintf("--data-dir=tikv-%s", dataDir), - "--addr=0.0.0.0:20160", - "--log-file=tikv.log", + fmt.Sprintf("--data-dir=%s/tikv-%s", *tmpPath, dataDir), + "--addr=127.0.0.1:20160", + fmt.Sprintf("--log-file=%s/tikv.log", *tmpPath), "--advertise-addr=127.0.0.1:20160", + "--config=tikv.toml", ) - log.Info("starting tikv") + log.Info("starting tikv", zap.Any("cmd", s.tikvProc)) err = s.tikvProc.Start() if err != nil { return errors.Trace(err) @@ -147,10 +148,10 @@ func (s *GlobalKillSuite) startTiKV(dataDir string) (err error) { func (s *GlobalKillSuite) startPD(dataDir string) (err error) { s.pdProc = exec.Command(*pdBinaryPath, "--name=pd", - "--log-file=pd.log", + fmt.Sprintf("--log-file=%s/pd.log", *tmpPath), fmt.Sprintf("--client-urls=http://%s", *pdClientPath), - fmt.Sprintf("--data-dir=pd-%s", dataDir)) - log.Info("starting pd") + fmt.Sprintf("--data-dir=%s/pd-%s", *tmpPath, dataDir)) + log.Info("starting pd", zap.Any("cmd", s.pdProc)) err = s.pdProc.Start() if err != nil { return errors.Trace(err) @@ -162,43 +163,43 @@ func (s *GlobalKillSuite) startPD(dataDir string) (err error) { func (s *GlobalKillSuite) startCluster() (err error) { err = s.startPD(s.clusterID) if err != nil { - return + return errors.Trace(err) } err = s.startTiKV(s.clusterID) if err != nil { - return + return errors.Trace(err) } time.Sleep(10 * time.Second) - return + return nil } func (s *GlobalKillSuite) stopPD() (err error) { if err = s.pdProc.Process.Kill(); err != nil { - return + return errors.Trace(err) } if err = s.pdProc.Wait(); err != nil && err.Error() != "signal: killed" { - return err + return errors.Trace(err) } return nil } func (s *GlobalKillSuite) stopTiKV() (err error) { if err = s.tikvProc.Process.Kill(); err != nil { - return + return errors.Trace(err) } if err = s.tikvProc.Wait(); err != nil && err.Error() != "signal: killed" { - return err + return errors.Trace(err) } return nil } func (s *GlobalKillSuite) cleanCluster() (err error) { if err = s.stopPD(); err != nil { - return err + return errors.Trace(err) } if err = s.stopTiKV(); err != nil { - return err + return errors.Trace(err) } log.Info("cluster cleaned") return nil @@ -212,6 +213,7 @@ func (s *GlobalKillSuite) startTiDBWithoutPD(port int, statusPort int) (cmd *exe fmt.Sprintf("-P=%d", port), fmt.Sprintf("--status=%d", statusPort), fmt.Sprintf("--log-file=%s/tidb%d.log", *tmpPath, port), + fmt.Sprintf("--log-slow-query=%s/tidb-slow%d.log", *tmpPath, port), fmt.Sprintf("--config=%s", "./config.toml")) log.Info("starting tidb", zap.Any("cmd", cmd)) err = cmd.Start() @@ -230,6 +232,7 @@ func (s *GlobalKillSuite) startTiDBWithPD(port int, statusPort int, pdPath strin fmt.Sprintf("-P=%d", port), fmt.Sprintf("--status=%d", statusPort), fmt.Sprintf("--log-file=%s/tidb%d.log", *tmpPath, port), + fmt.Sprintf("--log-slow-query=%s/tidb-slow%d.log", *tmpPath, port), fmt.Sprintf("--config=%s", "./config.toml")) log.Info("starting tidb", zap.Any("cmd", cmd)) err = cmd.Start() diff --git 
a/tests/globalkilltest/tikv.toml b/tests/globalkilltest/tikv.toml new file mode 100644 index 0000000000000..e6dd2afd33435 --- /dev/null +++ b/tests/globalkilltest/tikv.toml @@ -0,0 +1,2 @@ +[storage] +reserve-space = "1KB" diff --git a/tests/globalkilltest/up.sh b/tests/globalkilltest/up.sh new file mode 100755 index 0000000000000..61747f6d91786 --- /dev/null +++ b/tests/globalkilltest/up.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2023 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +set -euxo pipefail + +# Prepare pd-server & tikv-server +mkdir -p bin +tiup install pd:nightly tikv:nightly +cp ~/.tiup/components/pd/$(ls ~/.tiup/components/pd | tail -1)/pd-server bin/ +cp ~/.tiup/components/tikv/$(ls ~/.tiup/components/tikv | tail -1)/tikv-server bin/ + +cd ../.. +TIDB_PATH=$(pwd) + +docker build -t globalkilltest -f tests/globalkilltest/Dockerfile . +docker run --name globalkilltest -it --rm -v $TIDB_PATH:/tidb globalkilltest /bin/bash -c \ + 'git config --global --add safe.directory /tidb && cd /tidb/tests/globalkilltest && make && ./run-tests.sh' diff --git a/tidb-server/main.go b/tidb-server/main.go index 13570fa152042..fb887621a58ba 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -815,7 +815,6 @@ func createServer(storage kv.Storage, dom *domain.Domain) *server.Server { log.Fatal("failed to create the server", zap.Error(err), zap.Stack("stack")) } svr.SetDomain(dom) - svr.InitGlobalConnID(dom.ServerID) go dom.ExpensiveQueryHandle().SetSessionManager(svr).Run() go dom.MemoryUsageAlarmHandle().SetSessionManager(svr).Run() go dom.ServerMemoryLimitHandle().SetSessionManager(svr).Run() diff --git a/util/BUILD.bazel b/util/BUILD.bazel index 3e22b7a04eebc..f8fcd64a03e88 100644 --- a/util/BUILD.bazel +++ b/util/BUILD.bazel @@ -60,7 +60,6 @@ go_test( "main_test.go", "misc_test.go", "prefix_helper_test.go", - "processinfo_test.go", "security_test.go", "urls_test.go", "util_test.go", @@ -71,7 +70,6 @@ go_test( flaky = True, shard_count = 50, deps = [ - "//config", "//kv", "//parser", "//parser/model", diff --git a/util/expensivequery/expensivequery.go b/util/expensivequery/expensivequery.go index bbf1fd22eea25..280929ac75d75 100644 --- a/util/expensivequery/expensivequery.go +++ b/util/expensivequery/expensivequery.go @@ -71,7 +71,7 @@ func (eqh *Handle) Run() { zap.Duration("maxExecutionTime", time.Duration(info.MaxExecutionTime)*time.Millisecond), zap.String("processInfo", info.String())) sm.Kill(info.ID, true, true) } - if info.ID == util.GetAutoAnalyzeProcID(sm.ServerID) { + if info.ID == sm.GetAutoAnalyzeProcID() { maxAutoAnalyzeTime := variable.MaxAutoAnalyzeTime.Load() if maxAutoAnalyzeTime > 0 && costTime > time.Duration(maxAutoAnalyzeTime)*time.Second { logutil.BgLogger().Warn("auto analyze timeout, kill it", zap.Duration("costTime", costTime), diff --git a/util/globalconn/BUILD.bazel b/util/globalconn/BUILD.bazel index a330ab8ee85cf..fb99924ea9ffa 100644 --- a/util/globalconn/BUILD.bazel +++ b/util/globalconn/BUILD.bazel @@ -8,13 +8,21 @@ 
go_library( ], importpath = "github.com/pingcap/tidb/util/globalconn", visibility = ["//visibility:public"], - deps = ["@com_github_cznic_mathutil//:mathutil"], + deps = [ + "//util/logutil", + "@com_github_cznic_mathutil//:mathutil", + "@com_github_ngaut_sync2//:sync2", + "@org_uber_go_zap//:zap", + ], ) go_test( name = "globalconn_test", timeout = "short", - srcs = ["pool_test.go"], + srcs = [ + "globalconn_test.go", + "pool_test.go", + ], flaky = True, deps = [ ":globalconn", diff --git a/util/globalconn/globalconn.go b/util/globalconn/globalconn.go index 5e5ca9b3039dd..4cb46f81eedfc 100644 --- a/util/globalconn/globalconn.go +++ b/util/globalconn/globalconn.go @@ -1,4 +1,4 @@ -// Copyright 2022 PingCAP, Inc. +// Copyright 2023 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,17 @@ package globalconn -// GlobalConnID is the global connection ID, providing UNIQUE connection IDs across the whole TiDB cluster. +import ( + "errors" + "fmt" + "math" + + "github.com/ngaut/sync2" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" +) + +// GCID is the Global Connection ID, providing UNIQUE connection IDs across the whole TiDB cluster. // Used when GlobalKill feature is enable. // See https://github.com/pingcap/tidb/blob/master/docs/design/2020-06-01-global-kill.md // 32 bits version: @@ -32,6 +42,12 @@ package globalconn // | | serverId | local connId |markup| // |=0| (22b) | (40b) | =1 | // +--+---------------------+--------------------------------------+------+ +type GCID struct { + ServerID uint64 + LocalConnID uint64 + Is64bits bool +} + const ( // MaxServerID32 is maximum serverID for 32bits global connection ID. MaxServerID32 = 1<<11 - 1 @@ -46,4 +62,231 @@ const ( LocalConnIDBits64 = 40 // MaxLocalConnID64 is maximum localConnID for 64bits global connection ID. MaxLocalConnID64 = 1< MaxLocalConnID64 { + panic(fmt.Sprintf("unexpected localConnID %d exceeds %d", g.LocalConnID, MaxLocalConnID64)) + } + if g.ServerID > MaxServerID64 { + panic(fmt.Sprintf("unexpected serverID %d exceeds %d", g.ServerID, MaxServerID64)) + } + + id |= 0x1 + id |= g.LocalConnID << 1 // 40 bits local connID. + id |= g.ServerID << 41 // 22 bits serverID. + } else { + if g.LocalConnID > MaxLocalConnID32 { + panic(fmt.Sprintf("unexpected localConnID %d exceeds %d", g.LocalConnID, MaxLocalConnID32)) + } + if g.ServerID > MaxServerID32 { + panic(fmt.Sprintf("unexpected serverID %d exceeds %d", g.ServerID, MaxServerID32)) + } + + id |= g.LocalConnID << 1 // 20 bits local connID. + id |= g.ServerID << 21 // 11 bits serverID. + } + return id +} + +// ParseConnID parses an uint64 connection ID to GlobalConnID. +// +// `isTruncated` indicates that older versions of the client truncated the 64-bit GlobalConnID to 32-bit. 
+func ParseConnID(id uint64) (g GCID, isTruncated bool, err error) {
+	if id&0x80000000_00000000 > 0 {
+		return GCID{}, false, errors.New("unexpected connectionID exceeds int64")
+	}
+	if id&0x1 > 0 { // 64bits
+		if id&0xffffffff_00000000 == 0 {
+			return GCID{}, true, nil
+		}
+		return GCID{
+			Is64bits:    true,
+			LocalConnID: (id >> 1) & MaxLocalConnID64,
+			ServerID:    (id >> 41) & MaxServerID64,
+		}, false, nil
+	}
+
+	// 32bits
+	if id&0xffffffff_00000000 > 0 {
+		return GCID{}, false, errors.New("unexpected connectionID exceeds uint32")
+	}
+	return GCID{
+		Is64bits:    false,
+		LocalConnID: (id >> 1) & MaxLocalConnID32,
+		ServerID:    (id >> 21) & MaxServerID32,
+	}, false, nil
+}
+
+///////////////////////////////// Class Diagram ///////////////////////////////////
+//                                                                               //
+//  +----------+      +-----------------+         +-----------------------+      //
+//  | Server   | ---> | ConnIDAllocator | <<--+-- | GlobalConnIDAllocator | --+  //
+//  +----------+      +-----------------+     |   +-----------------------+   |  //
+//                                            +-- | SimpleConnIDAllocator |   |  //
+//                                                +----------+------------+   |  //
+//                                                           |                |  //
+//                                                           V                |  //
+//                         +--------+         +----------------------+        |  //
+//                         | IDPool | <<--+-- | AutoIncPool          | <------+  //
+//                         +--------+     |   +----------------------+        |  //
+//                                        +-- | LockFreeCircularPool | <------+  //
+//                                            +----------------------+           //
+//                                                                               //
+///////////////////////////////////////////////////////////////////////////////////
+
+type serverIDGetterFn func() uint64
+
+// Allocator allocates global connection IDs.
+type Allocator interface {
+	// NextID returns next connection ID.
+	NextID() uint64
+	// Release releases connection ID to allocator.
+	Release(connectionID uint64)
+	// GetReservedConnID returns reserved connection ID.
+	GetReservedConnID(reservedNo uint64) uint64
+}
+
+var (
+	_ Allocator = (*SimpleAllocator)(nil)
+	_ Allocator = (*GlobalAllocator)(nil)
 )
+
+// SimpleAllocator is a simple connection ID allocator used when the GlobalKill feature is disabled.
+type SimpleAllocator struct {
+	pool AutoIncPool
+}
+
+// NewSimpleAllocator creates a new SimpleAllocator.
+func NewSimpleAllocator() *SimpleAllocator {
+	a := &SimpleAllocator{}
+	a.pool.Init(math.MaxUint64 - ReservedCount)
+	return a
+}
+
+// NextID implements ConnIDAllocator interface.
+func (a *SimpleAllocator) NextID() uint64 {
+	id, _ := a.pool.Get()
+	return id
+}
+
+// Release implements ConnIDAllocator interface.
+func (a *SimpleAllocator) Release(id uint64) {
+	a.pool.Put(id)
+}
+
+// GetReservedConnID implements ConnIDAllocator interface.
+func (*SimpleAllocator) GetReservedConnID(reservedNo uint64) uint64 {
+	if reservedNo >= ReservedCount {
+		panic("invalid reservedNo exceed ReservedCount")
+	}
+	return math.MaxUint64 - reservedNo
+}
+
+// GlobalAllocator is the global connection ID allocator.
+type GlobalAllocator struct {
+	is64bits       sync2.AtomicInt32 // !0: true, 0: false
+	serverIDGetter func() uint64
+
+	local32 LockFreeCircularPool
+	local64 AutoIncPool
+}
+
+// Is64 indicates whether to allocate 64bits global connection ID.
+func (g *GlobalAllocator) Is64() bool {
+	return g.is64bits.Get() != 0
+}
+
+// UpgradeTo64 upgrades allocator to 64bits.
+func (g *GlobalAllocator) UpgradeTo64() {
+	g.is64bits.Set(1)
+}
+
+// LocalConnIDAllocator64TryCount is the try count of 64bits local connID allocation.
+const LocalConnIDAllocator64TryCount = 10
+
+// NewGlobalAllocator creates a GlobalAllocator.
+func NewGlobalAllocator(serverIDGetter serverIDGetterFn) *GlobalAllocator { + g := &GlobalAllocator{ + serverIDGetter: serverIDGetter, + } + g.local32.InitExt(1<= ReservedCount { + panic("invalid reservedNo exceed ReservedCount") + } + + serverID := g.serverIDGetter() + globalConnID := GCID{ + ServerID: serverID, + LocalConnID: (1 << LocalConnIDBits64) - 1 - reservedNo, + Is64bits: true, + } + return globalConnID.ToConnID() +} + +// Allocate allocates a new global connection ID. +func (g *GlobalAllocator) Allocate() GCID { + serverID := g.serverIDGetter() + + // 32bits. + if !g.Is64() { + localConnID, ok := g.local32.Get() + if ok { + return GCID{ + ServerID: serverID, + LocalConnID: localConnID, + Is64bits: false, + } + } + g.UpgradeTo64() // go on to 64bits. + } + + // 64bits. + localConnID, ok := g.local64.Get() + if !ok { + // local connID with 40bits pool size is big enough and should not be exhausted, as `MaxServerConnections` is no more than math.MaxUint32. + panic(fmt.Sprintf("Failed to allocate 64bits local connID after try %v times. Should never happen", LocalConnIDAllocator64TryCount)) + } + return GCID{ + ServerID: serverID, + LocalConnID: localConnID, + Is64bits: true, + } +} + +// Release releases connectionID to pool. +func (g *GlobalAllocator) Release(connectionID uint64) { + globalConnID, isTruncated, err := ParseConnID(connectionID) + if err != nil || isTruncated { + logutil.BgLogger().Error("failed to ParseGlobalConnID", zap.Error(err), zap.Uint64("connectionID", connectionID), zap.Bool("isTruncated", isTruncated)) + return + } + + if globalConnID.Is64bits { + g.local64.Put(globalConnID.LocalConnID) + } else { + if ok := g.local32.Put(globalConnID.LocalConnID); !ok { + logutil.BgLogger().Error("failed to release 32bits connection ID", zap.Uint64("connectionID", connectionID), zap.Uint64("localConnID", globalConnID.LocalConnID)) + } + } +} diff --git a/util/globalconn/globalconn_test.go b/util/globalconn/globalconn_test.go new file mode 100644 index 0000000000000..592e83e57707a --- /dev/null +++ b/util/globalconn/globalconn_test.go @@ -0,0 +1,229 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package globalconn_test + +import ( + "fmt" + "math" + "runtime" + "testing" + + "github.com/pingcap/tidb/util/globalconn" + "github.com/stretchr/testify/assert" +) + +func TestToConnID(t *testing.T) { + assert := assert.New(t) + + type Case struct { + gcid globalconn.GCID + shouldPanic bool + expected uint64 + } + + cases := []Case{ + { + gcid: globalconn.GCID{ + Is64bits: true, + ServerID: 1001, + LocalConnID: 123, + }, + shouldPanic: false, + expected: (uint64(1001) << 41) | (uint64(123) << 1) | 1, + }, + { + gcid: globalconn.GCID{ + Is64bits: true, + ServerID: 1 << 22, + LocalConnID: 123, + }, + shouldPanic: true, + expected: 0, + }, + { + gcid: globalconn.GCID{ + Is64bits: true, + ServerID: 1001, + LocalConnID: 1 << 40, + }, + shouldPanic: true, + expected: 0, + }, + { + gcid: globalconn.GCID{ + Is64bits: false, + ServerID: 1001, + LocalConnID: 123, + }, + shouldPanic: false, + expected: (uint64(1001) << 21) | (uint64(123) << 1), + }, + { + gcid: globalconn.GCID{ + Is64bits: false, + ServerID: 1 << 11, + LocalConnID: 123, + }, + shouldPanic: true, + expected: 0, + }, + { + gcid: globalconn.GCID{ + Is64bits: false, + ServerID: 1001, + LocalConnID: 1 << 20, + }, + shouldPanic: true, + expected: 0, + }, + } + + for _, c := range cases { + if c.shouldPanic { + assert.Panics(func() { + c.gcid.ToConnID() + }) + } else { + assert.Equal(c.expected, c.gcid.ToConnID()) + } + } +} + +func TestGlobalConnID(t *testing.T) { + assert := assert.New(t) + var ( + err error + isTruncated bool + ) + + // exceeds int64 + _, _, err = globalconn.ParseConnID(0x80000000_00000321) + assert.NotNil(err) + + // 64bits truncated + _, isTruncated, err = globalconn.ParseConnID(101) + assert.Nil(err) + assert.True(isTruncated) + + // 64bits + id1 := (uint64(1001) << 41) | (uint64(123) << 1) | 1 + gcid1, isTruncated, err := globalconn.ParseConnID(id1) + assert.Nil(err) + assert.False(isTruncated) + assert.Equal(uint64(1001), gcid1.ServerID) + assert.Equal(uint64(123), gcid1.LocalConnID) + assert.True(gcid1.Is64bits) + + // exceeds uint32 + _, _, err = globalconn.ParseConnID(0x1_00000320) + assert.NotNil(err) + + // 32bits + id2 := (uint64(2002) << 21) | (uint64(321) << 1) + gcid2, isTruncated, err := globalconn.ParseConnID(id2) + assert.Nil(err) + assert.False(isTruncated) + assert.Equal(uint64(2002), gcid2.ServerID) + assert.Equal(uint64(321), gcid2.LocalConnID) + assert.False(gcid2.Is64bits) + assert.Equal(gcid2.ToConnID(), id2) +} + +func TestGetReservedConnID(t *testing.T) { + assert := assert.New(t) + + simpleAlloc := globalconn.NewSimpleAllocator() + assert.Equal(math.MaxUint64-uint64(0), simpleAlloc.GetReservedConnID(0)) + assert.Equal(math.MaxUint64-uint64(1), simpleAlloc.GetReservedConnID(1)) + + serverID := func() uint64 { + return 1001 + } + + globalAlloc := globalconn.NewGlobalAllocator(serverID) + var maxLocalConnID uint64 = 1<<40 - 1 + assert.Equal(uint64(1001)<<41|(maxLocalConnID)<<1|1, globalAlloc.GetReservedConnID(0)) + assert.Equal(uint64(1001)<<41|(maxLocalConnID-1)<<1|1, globalAlloc.GetReservedConnID(1)) +} + +func benchmarkLocalConnIDAllocator32(b *testing.B, pool globalconn.IDPool) { + var ( + id uint64 + ok bool + ) + + // allocate local conn ID. + for { + if id, ok = pool.Get(); ok { + break + } + runtime.Gosched() + } + + // deallocate local conn ID. 
+ if ok = pool.Put(id); !ok { + b.Fatal("pool unexpected full") + } +} + +func BenchmarkLocalConnIDAllocator(b *testing.B) { + b.ReportAllocs() + + concurrencyCases := []int{1, 3, 10, 20, 100} + for _, concurrency := range concurrencyCases { + b.Run(fmt.Sprintf("Allocator 64 x%v", concurrency), func(b *testing.B) { + pool := globalconn.AutoIncPool{} + pool.InitExt(1< 0 { pool.InitForTest(headPos, fillCount) } diff --git a/util/processinfo.go b/util/processinfo.go index 2d70c1f1cf442..616682d5deb69 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -16,11 +16,9 @@ package util import ( "crypto/tls" - "errors" "fmt" "net" "strings" - "sync/atomic" "time" "github.com/pingcap/tidb/parser/mysql" @@ -189,6 +187,8 @@ type SessionManager interface { KillAllConnections() UpdateTLSConfig(cfg *tls.Config) ServerID() uint64 + // GetAutoAnalyzeProcID returns processID for auto analyze + GetAutoAnalyzeProcID() uint64 // StoreInternalSession puts the internal session pointer to the map in the SessionManager. StoreInternalSession(se interface{}) // DeleteInternalSession deletes the internal session pointer from the map in the SessionManager. @@ -200,109 +200,3 @@ type SessionManager interface { // KillNonFlashbackClusterConn kill all non flashback cluster connections. KillNonFlashbackClusterConn() } - -// GlobalConnID is the global connection ID, providing UNIQUE connection IDs across the whole TiDB cluster. -// 64 bits version: -/* - 63 62 41 40 1 0 - +--+---------------------+--------------------------------------+------+ - | | serverId | local connId |markup| - |=0| (22b) | (40b) | =1 | - +--+---------------------+--------------------------------------+------+ - 32 bits version(coming soon): - 31 1 0 - +-----------------------------+------+ - | ??? |markup| - | ??? | =0 | - +-----------------------------+------+ -*/ -type GlobalConnID struct { - ServerIDGetter func() uint64 - ServerID uint64 - LocalConnID uint64 - Is64bits bool -} - -// NewGlobalConnID creates GlobalConnID with serverID -func NewGlobalConnID(serverID uint64, is64Bits bool) GlobalConnID { - return GlobalConnID{ServerID: serverID, Is64bits: is64Bits, LocalConnID: reservedLocalConns} -} - -// NewGlobalConnIDWithGetter creates GlobalConnID with serverIDGetter -func NewGlobalConnIDWithGetter(serverIDGetter func() uint64, is64Bits bool) GlobalConnID { - return GlobalConnID{ServerIDGetter: serverIDGetter, Is64bits: is64Bits, LocalConnID: reservedLocalConns} -} - -const ( - // MaxServerID is maximum serverID. - MaxServerID = 1<<22 - 1 -) - -func (g *GlobalConnID) makeID(localConnID uint64) uint64 { - var ( - id uint64 - serverID uint64 - ) - if g.ServerIDGetter != nil { - serverID = g.ServerIDGetter() - } else { - serverID = g.ServerID - } - if g.Is64bits { - id |= 0x1 - id |= localConnID & 0xff_ffff_ffff << 1 // 40 bits local connID. - id |= serverID & MaxServerID << 41 // 22 bits serverID. - } else { - // TODO: update after new design for 32 bits version. - id |= localConnID & 0x7fff_ffff << 1 // 31 bits local connID. - } - return id -} - -// ID returns the connection id -func (g *GlobalConnID) ID() uint64 { - return g.makeID(g.LocalConnID) -} - -// NextID returns next connection id -func (g *GlobalConnID) NextID() uint64 { - localConnID := atomic.AddUint64(&g.LocalConnID, 1) - return g.makeID(localConnID) -} - -// ParseGlobalConnID parses an uint64 to GlobalConnID. -// -// `isTruncated` indicates that older versions of the client truncated the 64-bit GlobalConnID to 32-bit. 
-func ParseGlobalConnID(id uint64) (g GlobalConnID, isTruncated bool, err error) { - if id&0x80000000_00000000 > 0 { - return GlobalConnID{}, false, errors.New("Unexpected connectionID excceeds int64") - } - if id&0x1 > 0 { - if id&0xffffffff_00000000 == 0 { - return GlobalConnID{}, true, nil - } - return GlobalConnID{ - Is64bits: true, - LocalConnID: (id >> 1) & 0xff_ffff_ffff, - ServerID: (id >> 41) & MaxServerID, - }, false, nil - } - // TODO: update after new design for 32 bits version. - return GlobalConnID{ - Is64bits: false, - LocalConnID: (id >> 1) & 0x7fff_ffff, - ServerID: 0, - }, false, nil -} - -const ( - reservedLocalConns = 200 - reservedConnAnalyze = 1 -) - -// GetAutoAnalyzeProcID returns processID for auto analyze -// TODO support IDs for concurrent auto-analyze -func GetAutoAnalyzeProcID(serverIDGetter func() uint64) uint64 { - globalConnID := NewGlobalConnIDWithGetter(serverIDGetter, true) - return globalConnID.makeID(reservedConnAnalyze) -} diff --git a/util/processinfo_test.go b/util/processinfo_test.go deleted file mode 100644 index 2f14020325794..0000000000000 --- a/util/processinfo_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package util_test - -import ( - "testing" - - "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/util" - "github.com/stretchr/testify/assert" -) - -func TestGlobalConnID(t *testing.T) { - originCfg := config.GetGlobalConfig() - newCfg := *originCfg - newCfg.EnableGlobalKill = true - config.StoreGlobalConfig(&newCfg) - defer func() { - config.StoreGlobalConfig(originCfg) - }() - connID := util.GlobalConnID{ - Is64bits: true, - ServerID: 1001, - LocalConnID: 123, - } - assert.Equal(t, (uint64(1001)<<41)|(uint64(123)<<1)|1, connID.ID()) - - next := connID.NextID() - assert.Equal(t, (uint64(1001)<<41)|(uint64(124)<<1)|1, next) - - connID1, isTruncated, err := util.ParseGlobalConnID(next) - assert.Nil(t, err) - assert.False(t, isTruncated) - assert.Equal(t, uint64(1001), connID1.ServerID) - assert.Equal(t, uint64(124), connID1.LocalConnID) - assert.True(t, connID1.Is64bits) - - _, isTruncated, err = util.ParseGlobalConnID(101) - assert.Nil(t, err) - assert.True(t, isTruncated) - - _, _, err = util.ParseGlobalConnID(0x80000000_00000321) - assert.NotNil(t, err) - - connID2 := util.GlobalConnID{ - Is64bits: true, - ServerIDGetter: func() uint64 { return 2002 }, - LocalConnID: 123, - } - assert.Equal(t, (uint64(2002)<<41)|(uint64(123)<<1)|1, connID2.ID()) -}
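For reference, the 64-bit layout shared by the removed `util.GlobalConnID` and the new `globalconn.GCID` — top bit 0, 22-bit serverID, 40-bit local connID, markup bit 1 — can be reproduced standalone. The sketch below mirrors `GCID.ToConnID` and `ParseConnID` from the diff, using the same shifts and masks asserted in `TestToConnID`/`TestGlobalConnID`, without importing the package; all names are local to the example.

```go
package main

import (
	"errors"
	"fmt"
)

const (
	localConnIDBits64 = 40
	maxLocalConnID64  = 1<<localConnIDBits64 - 1 // 40-bit local connection ID mask
	maxServerID64     = 1<<22 - 1                // 22-bit server ID mask
)

// encode64 mirrors GCID.ToConnID for the 64-bit form:
// |0| serverID (22b) | local connID (40b) | markup=1 |
func encode64(serverID, localConnID uint64) uint64 {
	return serverID<<41 | localConnID<<1 | 1
}

// decode64 mirrors ParseConnID for the 64-bit form.
func decode64(id uint64) (serverID, localConnID uint64, err error) {
	if id&0x8000000000000000 != 0 {
		return 0, 0, errors.New("connection ID exceeds int64")
	}
	if id&0x1 == 0 {
		return 0, 0, errors.New("not a 64-bit global connection ID")
	}
	return (id >> 41) & maxServerID64, (id >> 1) & maxLocalConnID64, nil
}

func main() {
	// Same numbers as TestGlobalConnID: serverID 1001, local connID 123.
	id := encode64(1001, 123)
	fmt.Printf("0x%016x\n", id)

	serverID, localConnID, err := decode64(id)
	fmt.Println(serverID, localConnID, err) // 1001 123 <nil>

	// Reserved IDs such as the auto-analyze process ID sit at the top of the
	// 40-bit local range: reservedNo 0 -> localConnID 2^40-1.
	fmt.Println(encode64(1001, maxLocalConnID64)) // matches TestGetReservedConnID for reservedNo 0
}
```

Because the reserved slots are just the top values of the 40-bit local range, the auto-analyze worker's process ID is an ordinary global connection ID, which is what lets `planner/core/planbuilder.go` gate `KILL` on it with the usual CONNECTION_ADMIN check.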