From d240f00a33cc329e35f28cac095b6428a075f98b Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Tue, 1 Aug 2023 19:22:08 +0800 Subject: [PATCH] lightning: fix pd http request using old address (#45680) close pingcap/tidb#43436 --- br/pkg/lightning/backend/local/local.go | 12 ++++---- br/pkg/lightning/backend/local/tikv_mode.go | 9 +++--- br/pkg/lightning/importer/BUILD.bazel | 2 ++ br/pkg/lightning/importer/checksum_helper.go | 12 ++------ br/pkg/lightning/importer/get_pre_info.go | 15 +++++++--- .../lightning/importer/get_pre_info_test.go | 5 +++- br/pkg/lightning/importer/import.go | 19 +++++++++--- br/pkg/lightning/importer/precheck.go | 6 ++-- .../lightning/importer/table_import_test.go | 29 ++++++++++++++++--- disttask/importinto/dispatcher.go | 6 ++-- executor/importer/BUILD.bazel | 1 + executor/importer/table_import.go | 15 +++++++--- tests/realtikvtest/importintotest/BUILD.bazel | 1 + .../importintotest/import_into_test.go | 3 +- 14 files changed, 93 insertions(+), 42 deletions(-) diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index 5a8217119a485..7c6609e1720c4 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -268,15 +268,15 @@ func (*encodingBuilder) MakeEmptyRows() encode.Rows { type targetInfoGetter struct { tls *common.TLS targetDB *sql.DB - pdAddr string + pdCli pd.Client } // NewTargetInfoGetter creates an TargetInfoGetter with local backend implementation. -func NewTargetInfoGetter(tls *common.TLS, db *sql.DB, pdAddr string) backend.TargetInfoGetter { +func NewTargetInfoGetter(tls *common.TLS, db *sql.DB, pdCli pd.Client) backend.TargetInfoGetter { return &targetInfoGetter{ tls: tls, targetDB: db, - pdAddr: pdAddr, + pdCli: pdCli, } } @@ -297,10 +297,10 @@ func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *back if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { return err } - if err := tikv.CheckPDVersion(ctx, g.tls, g.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil { + if err := tikv.CheckPDVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinPDVersion, localMaxPDVersion); err != nil { return err } - if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil { + if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinTiKVVersion, localMaxTiKVVersion); err != nil { return err } @@ -1719,7 +1719,7 @@ func (local *Backend) LocalWriter(_ context.Context, cfg *backend.LocalWriterCon // This function will spawn a goroutine to keep switch mode periodically until the context is done. // The return done channel is used to notify the caller that the background goroutine is exited. func (local *Backend) SwitchModeByKeyRanges(ctx context.Context, ranges []Range) (<-chan struct{}, error) { - switcher := NewTiKVModeSwitcher(local.tls, local.PDAddr, log.FromContext(ctx).Logger) + switcher := NewTiKVModeSwitcher(local.tls, local.pdCtl.GetPDClient(), log.FromContext(ctx).Logger) done := make(chan struct{}) keyRanges := make([]*sst.Range, 0, len(ranges)) diff --git a/br/pkg/lightning/backend/local/tikv_mode.go b/br/pkg/lightning/backend/local/tikv_mode.go index 0abdc69393381..69345e58a8f4e 100644 --- a/br/pkg/lightning/backend/local/tikv_mode.go +++ b/br/pkg/lightning/backend/local/tikv_mode.go @@ -20,6 +20,7 @@ import ( sstpb "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/tikv" + pd "github.com/tikv/pd/client" "go.uber.org/zap" ) @@ -34,15 +35,15 @@ type TiKVModeSwitcher interface { // TiKVModeSwitcher is used to switch TiKV nodes between Import and Normal mode. type switcher struct { tls *common.TLS - pdAddr string + pdCli pd.Client logger *zap.Logger } // NewTiKVModeSwitcher creates a new TiKVModeSwitcher. -func NewTiKVModeSwitcher(tls *common.TLS, pdAddr string, logger *zap.Logger) TiKVModeSwitcher { +func NewTiKVModeSwitcher(tls *common.TLS, pdCli pd.Client, logger *zap.Logger) TiKVModeSwitcher { return &switcher{ tls: tls, - pdAddr: pdAddr, + pdCli: pdCli, logger: logger, } } @@ -68,7 +69,7 @@ func (rc *switcher) switchTiKVMode(ctx context.Context, mode sstpb.SwitchMode, r } else { minState = tikv.StoreStateDisconnected } - tls := rc.tls.WithHost(rc.pdAddr) + tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr()) // we ignore switch mode failure since it is not fatal. // no need log the error, it is done in kv.SwitchMode already. _ = tikv.ForAllStores( diff --git a/br/pkg/lightning/importer/BUILD.bazel b/br/pkg/lightning/importer/BUILD.bazel index 6a1ffb91e7076..ca7cd11a97f7e 100644 --- a/br/pkg/lightning/importer/BUILD.bazel +++ b/br/pkg/lightning/importer/BUILD.bazel @@ -175,6 +175,8 @@ go_test( "@com_github_stretchr_testify//require", "@com_github_stretchr_testify//suite", "@com_github_tikv_client_go_v2//config", + "@com_github_tikv_client_go_v2//testutils", + "@com_github_tikv_pd_client//:client", "@com_github_xitongsys_parquet_go//writer", "@com_github_xitongsys_parquet_go_source//buffer", "@io_etcd_go_etcd_client_v3//:client", diff --git a/br/pkg/lightning/importer/checksum_helper.go b/br/pkg/lightning/importer/checksum_helper.go index b77116432910f..1809efac8594f 100644 --- a/br/pkg/lightning/importer/checksum_helper.go +++ b/br/pkg/lightning/importer/checksum_helper.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/metric" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/kv" - pd "github.com/tikv/pd/client" "go.uber.org/zap" ) @@ -37,8 +36,7 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) ( return nil, nil } - pdAddr := rc.cfg.TiDB.PdAddr - pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, pdAddr) + pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, rc.pdCli.GetLeaderAddr()) if err != nil { return nil, errors.Trace(err) } @@ -46,12 +44,6 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) ( // for v4.0.0 or upper, we can use the gc ttl api var manager local.ChecksumManager if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL { - tlsOpt := rc.tls.ToPDSecurityOption() - pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt) - if err != nil { - return nil, errors.Trace(err) - } - backoffWeight, err := common.GetBackoffWeightFromDB(ctx, rc.db) // only set backoff weight when it's smaller than default value if err == nil && backoffWeight >= local.DefaultBackoffWeight { @@ -66,7 +58,7 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) ( log.FromContext(ctx).Warn("get tidb_request_source_type failed", zap.Error(err), zap.String("tidb_request_source_type", explicitRequestSourceType)) return nil, errors.Trace(err) } - manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight, rc.resourceGroupName, explicitRequestSourceType) + manager = local.NewTiKVChecksumManager(store.GetClient(), rc.pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight, rc.resourceGroupName, explicitRequestSourceType) } else { manager = local.NewTiDBChecksumExecutor(rc.db) } diff --git a/br/pkg/lightning/importer/get_pre_info.go b/br/pkg/lightning/importer/get_pre_info.go index 8ad3937faa19a..191fed628bdf6 100644 --- a/br/pkg/lightning/importer/get_pre_info.go +++ b/br/pkg/lightning/importer/get_pre_info.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/mock" + pd "github.com/tikv/pd/client" "go.uber.org/zap" "golang.org/x/exp/maps" ) @@ -123,12 +124,14 @@ type TargetInfoGetterImpl struct { db *sql.DB tls *common.TLS backend backend.TargetInfoGetter + pdCli pd.Client } // NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object. func NewTargetInfoGetterImpl( cfg *config.Config, targetDB *sql.DB, + pdCli pd.Client, ) (*TargetInfoGetterImpl, error) { tls, err := cfg.ToTLS() if err != nil { @@ -139,7 +142,10 @@ func NewTargetInfoGetterImpl( case config.BackendTiDB: backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB) case config.BackendLocal: - backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, cfg.TiDB.PdAddr) + if pdCli == nil { + return nil, common.ErrUnknown.GenWithStack("pd client is required when using local backend") + } + backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, pdCli) default: return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) } @@ -148,6 +154,7 @@ func NewTargetInfoGetterImpl( tls: tls, db: targetDB, backend: backendTargetInfoGetter, + pdCli: pdCli, }, nil } @@ -229,7 +236,7 @@ func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Contex // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) { result := new(pdtypes.ReplicationConfig) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdReplicate, &result); err != nil { return nil, errors.Trace(err) } return result, nil @@ -240,7 +247,7 @@ func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtyp // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) { result := new(pdtypes.StoresInfo) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdStores, result); err != nil { return nil, errors.Trace(err) } return result, nil @@ -251,7 +258,7 @@ func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.Sto // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) { result := new(pdtypes.RegionsInfo) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdEmptyRegions, &result); err != nil { return nil, errors.Trace(err) } return result, nil diff --git a/br/pkg/lightning/importer/get_pre_info_test.go b/br/pkg/lightning/importer/get_pre_info_test.go index 66480654cdfd8..7fda215beaa85 100644 --- a/br/pkg/lightning/importer/get_pre_info_test.go +++ b/br/pkg/lightning/importer/get_pre_info_test.go @@ -757,7 +757,10 @@ func TestGetPreInfoIsTableEmpty(t *testing.T) { require.NoError(t, err) lnConfig := config.NewConfig() lnConfig.TikvImporter.Backend = config.BackendLocal - targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db) + _, err = NewTargetInfoGetterImpl(lnConfig, db, nil) + require.ErrorContains(t, err, "pd client is required when using local backend") + lnConfig.TikvImporter.Backend = config.BackendTiDB + targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db, nil) require.NoError(t, err) require.Equal(t, lnConfig, targetGetter.cfg) diff --git a/br/pkg/lightning/importer/import.go b/br/pkg/lightning/importer/import.go index 4fdfbde99fbb4..0b0e9f677a384 100644 --- a/br/pkg/lightning/importer/import.go +++ b/br/pkg/lightning/importer/import.go @@ -200,6 +200,7 @@ type Controller struct { engineMgr backend.EngineManager backend backend.Backend db *sql.DB + pdCli pd.Client alterTableLock sync.Mutex sysVars map[string]string @@ -333,6 +334,7 @@ func NewImportControllerWithPauser( var encodingBuilder encode.EncodingBuilder var backendObj backend.Backend + var pdCli pd.Client switch cfg.TikvImporter.Backend { case config.BackendTiDB: encodingBuilder = tidb.NewEncodingBuilder() @@ -348,9 +350,13 @@ func NewImportControllerWithPauser( if maxOpenFiles < 0 { maxOpenFiles = math.MaxInt32 } + pdCli, err = pd.NewClientWithContext(ctx, []string{cfg.TiDB.PdAddr}, tls.ToPDSecurityOption()) + if err != nil { + return nil, errors.Trace(err) + } if cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone { - if err := tikv.CheckTiKVVersion(ctx, tls, cfg.TiDB.PdAddr, minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil { + if err := tikv.CheckTiKVVersion(ctx, tls, pdCli.GetLeaderAddr(), minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil { if !berrors.Is(err, berrors.ErrVersionMismatch) { return nil, common.ErrCheckKVVersion.Wrap(err).GenWithStackByArgs() } @@ -419,7 +425,7 @@ func NewImportControllerWithPauser( var wrapper backend.TargetInfoGetter if cfg.TikvImporter.Backend == config.BackendLocal { - wrapper = local.NewTargetInfoGetter(tls, db, cfg.TiDB.PdAddr) + wrapper = local.NewTargetInfoGetter(tls, db, pdCli) } else { wrapper = tidb.NewTargetInfoGetter(db) } @@ -429,6 +435,7 @@ func NewImportControllerWithPauser( db: db, tls: tls, backend: wrapper, + pdCli: pdCli, } preInfoGetter, err := NewPreImportInfoGetter( cfg, @@ -458,6 +465,7 @@ func NewImportControllerWithPauser( pauser: p.Pauser, engineMgr: backend.MakeEngineManager(backendObj), backend: backendObj, + pdCli: pdCli, db: db, sysVars: common.DefaultImportantVariables, tls: tls, @@ -482,7 +490,7 @@ func NewImportControllerWithPauser( preInfoGetter: preInfoGetter, precheckItemBuilder: preCheckBuilder, encBuilder: encodingBuilder, - tikvModeSwitcher: local.NewTiKVModeSwitcher(tls, cfg.TiDB.PdAddr, log.FromContext(ctx).Logger), + tikvModeSwitcher: local.NewTiKVModeSwitcher(tls, pdCli, log.FromContext(ctx).Logger), keyspaceName: p.KeyspaceName, resourceGroupName: p.ResourceGroupName, @@ -495,6 +503,9 @@ func NewImportControllerWithPauser( func (rc *Controller) Close() { rc.backend.Close() _ = rc.db.Close() + if rc.pdCli != nil { + rc.pdCli.Close() + } } // Run starts the restore task. @@ -1870,7 +1881,7 @@ func (rc *Controller) fullCompact(ctx context.Context) error { } func (rc *Controller) doCompact(ctx context.Context, level int32) error { - tls := rc.tls.WithHost(rc.cfg.TiDB.PdAddr) + tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr()) return tikv.ForAllStores( ctx, tls, diff --git a/br/pkg/lightning/importer/precheck.go b/br/pkg/lightning/importer/precheck.go index 1658229321edb..735e17f163ca2 100644 --- a/br/pkg/lightning/importer/precheck.go +++ b/br/pkg/lightning/importer/precheck.go @@ -9,6 +9,7 @@ import ( ropts "github.com/pingcap/tidb/br/pkg/lightning/importer/opts" "github.com/pingcap/tidb/br/pkg/lightning/mydump" "github.com/pingcap/tidb/br/pkg/lightning/precheck" + pd "github.com/tikv/pd/client" ) type precheckContextKey string @@ -29,7 +30,8 @@ type PrecheckItemBuilder struct { } // NewPrecheckItemBuilderFromConfig creates a new PrecheckItemBuilder from config -func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) { +// pdCli **must not** be nil for local backend +func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, pdCli pd.Client, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) { var gerr error builderCfg := new(ropts.PrecheckItemBuilderConfig) for _, o := range opts { @@ -39,7 +41,7 @@ func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, o if err != nil { return nil, errors.Trace(err) } - targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB) + targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB, pdCli) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/importer/table_import_test.go b/br/pkg/lightning/importer/table_import_test.go index 39d105662f5ee..7db51e680f83d 100644 --- a/br/pkg/lightning/importer/table_import_test.go +++ b/br/pkg/lightning/importer/table_import_test.go @@ -69,6 +69,8 @@ import ( filter "github.com/pingcap/tidb/util/table-filter" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/testutils" + pd "github.com/tikv/pd/client" ) const ( @@ -1161,6 +1163,8 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { require.NoError(s.T(), err) mockStore, err := storage.NewLocalStorage(dir) require.NoError(s.T(), err) + _, _, pdClient, err := testutils.NewMockTiKV("", nil) + require.NoError(s.T(), err) for _, ca := range cases { server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { var err error @@ -1177,9 +1181,11 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { url := strings.TrimPrefix(server.URL, "https://") cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}} + pdCli := &mockPDClient{Client: pdClient, leaderAddr: url} targetInfoGetter := &TargetInfoGetterImpl{ - cfg: cfg, - tls: tls, + cfg: cfg, + tls: tls, + pdCli: pdCli, } preInfoGetter := &PreImportInfoGetterImpl{ cfg: cfg, @@ -1194,6 +1200,7 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { checkTemplate: template, preInfoGetter: preInfoGetter, precheckItemBuilder: theCheckBuilder, + pdCli: pdCli, } var sourceSize int64 err = rc.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error { @@ -1230,6 +1237,15 @@ func (mockTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(ta return err } +type mockPDClient struct { + pd.Client + leaderAddr string +} + +func (m *mockPDClient) GetLeaderAddr() string { + return m.leaderAddr +} + func (s *tableRestoreSuite) TestCheckClusterRegion() { type testCase struct { stores pdtypes.StoresInfo @@ -1245,6 +1261,8 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { } return regions } + _, _, pdClient, err := testutils.NewMockTiKV("", nil) + require.NoError(s.T(), err) testCases := []testCase{ { @@ -1320,10 +1338,12 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { url := strings.TrimPrefix(server.URL, "https://") cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}} + pdCli := &mockPDClient{Client: pdClient, leaderAddr: url} targetInfoGetter := &TargetInfoGetterImpl{ - cfg: cfg, - tls: tls, + cfg: cfg, + tls: tls, + pdCli: pdCli, } dbMetas := []*mydump.MDDatabaseMeta{} preInfoGetter := &PreImportInfoGetterImpl{ @@ -1340,6 +1360,7 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { preInfoGetter: preInfoGetter, dbInfos: make(map[string]*checkpoints.TidbDBInfo), precheckItemBuilder: theCheckBuilder, + pdCli: pdCli, } preInfoGetter.dbInfosCache = rc.dbInfos diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go index 6b839b4af591e..6c9e04b368cef 100644 --- a/disttask/importinto/dispatcher.go +++ b/disttask/importinto/dispatcher.go @@ -165,12 +165,13 @@ func (h *flowHandle) switchTiKVMode(ctx context.Context, task *proto.Task) { } logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) - switcher, err := importer.GetTiKVModeSwitcher(logger) + pdCli, switcher, err := importer.GetTiKVModeSwitcherWithPDClient(ctx, logger) if err != nil { logger.Warn("get tikv mode switcher failed", zap.Error(err)) return } switcher.ToImportMode(ctx) + pdCli.Close() h.lastSwitchTime.Store(time.Now()) } @@ -335,12 +336,13 @@ func (h *flowHandle) switchTiKV2NormalMode(ctx context.Context, task *proto.Task h.mu.Lock() defer h.mu.Unlock() - switcher, err := importer.GetTiKVModeSwitcher(logger) + pdCli, switcher, err := importer.GetTiKVModeSwitcherWithPDClient(ctx, logger) if err != nil { logger.Warn("get tikv mode switcher failed", zap.Error(err)) return } switcher.ToNormalMode(ctx) + pdCli.Close() // clear it, so next task can switch TiKV mode again. h.lastSwitchTime.Store(time.Time{}) diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel index 9b7041623f3c8..2939aa4b18881 100644 --- a/executor/importer/BUILD.bazel +++ b/executor/importer/BUILD.bazel @@ -65,6 +65,7 @@ go_library( "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_client_go_v2//util", + "@com_github_tikv_pd_client//:client", "@org_golang_x_exp//slices", "@org_golang_x_sync//errgroup", "@org_uber_go_multierr//:multierr", diff --git a/executor/importer/table_import.go b/executor/importer/table_import.go index 549bcba46dd8a..25156eabb755a 100644 --- a/executor/importer/table_import.go +++ b/executor/importer/table_import.go @@ -45,6 +45,7 @@ import ( "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/syncutil" + pd "github.com/tikv/pd/client" "go.uber.org/multierr" "go.uber.org/zap" ) @@ -108,8 +109,8 @@ func prepareSortDir(e *LoadDataController, taskID int64, tidbCfg *tidb.Config) ( return sortDir, nil } -// GetTiKVModeSwitcher creates a new TiKV mode switcher. -func GetTiKVModeSwitcher(logger *zap.Logger) (local.TiKVModeSwitcher, error) { +// GetTiKVModeSwitcherWithPDClient creates a new TiKV mode switcher with its pd Client. +func GetTiKVModeSwitcherWithPDClient(ctx context.Context, logger *zap.Logger) (pd.Client, local.TiKVModeSwitcher, error) { tidbCfg := tidb.GetGlobalConfig() hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) tls, err := common.NewTLS( @@ -120,9 +121,15 @@ func GetTiKVModeSwitcher(logger *zap.Logger) (local.TiKVModeSwitcher, error) { nil, nil, nil, ) if err != nil { - return nil, err + return nil, nil, err + } + tlsOpt := tls.ToPDSecurityOption() + pdCli, err := pd.NewClientWithContext(ctx, []string{tidbCfg.Path}, tlsOpt) + if err != nil { + return nil, nil, errors.Trace(err) } - return NewTiKVModeSwitcher(tls, tidbCfg.Path, logger), nil + + return pdCli, NewTiKVModeSwitcher(tls, pdCli, logger), nil } func getCachedKVStoreFrom(pdAddr string, tls *common.TLS) (tidbkv.Storage, error) { diff --git a/tests/realtikvtest/importintotest/BUILD.bazel b/tests/realtikvtest/importintotest/BUILD.bazel index ea5de77ac5ba1..79b23d30350c2 100644 --- a/tests/realtikvtest/importintotest/BUILD.bazel +++ b/tests/realtikvtest/importintotest/BUILD.bazel @@ -55,6 +55,7 @@ go_test( "@com_github_pingcap_log//:log", "@com_github_stretchr_testify//require", "@com_github_stretchr_testify//suite", + "@com_github_tikv_pd_client//:client", "@io_etcd_go_etcd_client_v3//:client", "@org_uber_go_atomic//:atomic", "@org_uber_go_zap//:zap", diff --git a/tests/realtikvtest/importintotest/import_into_test.go b/tests/realtikvtest/importintotest/import_into_test.go index eed08a60eb523..d5eeec4354d65 100644 --- a/tests/realtikvtest/importintotest/import_into_test.go +++ b/tests/realtikvtest/importintotest/import_into_test.go @@ -49,6 +49,7 @@ import ( "github.com/pingcap/tidb/util/dbterror/exeerrors" "github.com/pingcap/tidb/util/sem" "github.com/stretchr/testify/require" + pd "github.com/tikv/pd/client" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/atomic" "go.uber.org/zap" @@ -857,7 +858,7 @@ func (s *mockGCSSuite) TestImportMode() { switcher.EXPECT().ToImportMode(gomock.Any(), gomock.Any()).DoAndReturn(toImportModeFn).Times(1) switcher.EXPECT().ToNormalMode(gomock.Any(), gomock.Any()).DoAndReturn(toNormalModeFn).Times(1) backup := importer.NewTiKVModeSwitcher - importer.NewTiKVModeSwitcher = func(tls *common.TLS, pdAddr string, logger *zap.Logger) local.TiKVModeSwitcher { + importer.NewTiKVModeSwitcher = func(tls *common.TLS, pdCli pd.Client, logger *zap.Logger) local.TiKVModeSwitcher { return switcher } s.T().Cleanup(func() {