Skip to content

Commit

Permalink
lightning: fix pd http request using old address (#45680) (#45737)
Browse files Browse the repository at this point in the history
close #43436
  • Loading branch information
ti-chi-bot committed Aug 4, 2023
1 parent 41ae469 commit 203d839
Show file tree
Hide file tree
Showing 18 changed files with 149 additions and 45 deletions.
12 changes: 6 additions & 6 deletions br/pkg/lightning/backend/local/local.go
Expand Up @@ -268,15 +268,15 @@ func (*encodingBuilder) MakeEmptyRows() encode.Rows {
type targetInfoGetter struct {
tls *common.TLS
targetDB *sql.DB
pdAddr string
pdCli pd.Client
}

// NewTargetInfoGetter creates an TargetInfoGetter with local backend implementation.
func NewTargetInfoGetter(tls *common.TLS, db *sql.DB, pdAddr string) backend.TargetInfoGetter {
func NewTargetInfoGetter(tls *common.TLS, db *sql.DB, pdCli pd.Client) backend.TargetInfoGetter {
return &targetInfoGetter{
tls: tls,
targetDB: db,
pdAddr: pdAddr,
pdCli: pdCli,
}
}

Expand All @@ -297,10 +297,10 @@ func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *back
if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil {
return err
}
if err := tikv.CheckPDVersion(ctx, g.tls, g.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil {
if err := tikv.CheckPDVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinPDVersion, localMaxPDVersion); err != nil {
return err
}
if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil {
if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinTiKVVersion, localMaxTiKVVersion); err != nil {
return err
}

Expand Down Expand Up @@ -1719,7 +1719,7 @@ func (local *Backend) LocalWriter(_ context.Context, cfg *backend.LocalWriterCon
// This function will spawn a goroutine to keep switch mode periodically until the context is done.
// The return done channel is used to notify the caller that the background goroutine is exited.
func (local *Backend) SwitchModeByKeyRanges(ctx context.Context, ranges []Range) (<-chan struct{}, error) {
switcher := NewTiKVModeSwitcher(local.tls, local.PDAddr, log.FromContext(ctx).Logger)
switcher := NewTiKVModeSwitcher(local.tls, local.pdCtl.GetPDClient(), log.FromContext(ctx).Logger)
done := make(chan struct{})

keyRanges := make([]*sst.Range, 0, len(ranges))
Expand Down
9 changes: 5 additions & 4 deletions br/pkg/lightning/backend/local/tikv_mode.go
Expand Up @@ -20,6 +20,7 @@ import (
sstpb "github.com/pingcap/kvproto/pkg/import_sstpb"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/tikv"
pd "github.com/tikv/pd/client"
"go.uber.org/zap"
)

Expand All @@ -34,15 +35,15 @@ type TiKVModeSwitcher interface {
// TiKVModeSwitcher is used to switch TiKV nodes between Import and Normal mode.
type switcher struct {
tls *common.TLS
pdAddr string
pdCli pd.Client
logger *zap.Logger
}

// NewTiKVModeSwitcher creates a new TiKVModeSwitcher.
func NewTiKVModeSwitcher(tls *common.TLS, pdAddr string, logger *zap.Logger) TiKVModeSwitcher {
func NewTiKVModeSwitcher(tls *common.TLS, pdCli pd.Client, logger *zap.Logger) TiKVModeSwitcher {
return &switcher{
tls: tls,
pdAddr: pdAddr,
pdCli: pdCli,
logger: logger,
}
}
Expand All @@ -68,7 +69,7 @@ func (rc *switcher) switchTiKVMode(ctx context.Context, mode sstpb.SwitchMode, r
} else {
minState = tikv.StoreStateDisconnected
}
tls := rc.tls.WithHost(rc.pdAddr)
tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr())
// we ignore switch mode failure since it is not fatal.
// no need log the error, it is done in kv.SwitchMode already.
_ = tikv.ForAllStores(
Expand Down
2 changes: 1 addition & 1 deletion br/pkg/lightning/common/BUILD.bazel
Expand Up @@ -104,7 +104,7 @@ go_test(
],
embed = [":common"],
flaky = True,
shard_count = 19,
shard_count = 20,
deps = [
"//br/pkg/errors",
"//br/pkg/lightning/log",
Expand Down
10 changes: 10 additions & 0 deletions br/pkg/lightning/common/security.go
Expand Up @@ -20,6 +20,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/httputil"
Expand Down Expand Up @@ -88,9 +89,18 @@ func NewTLSFromMockServer(server *httptest.Server) *TLS {
}
}

// GetMockTLSUrl returns tls's host for mock test
func GetMockTLSUrl(tls *TLS) string {
return tls.url
}

// WithHost creates a new TLS instance with the host replaced.
func (tc *TLS) WithHost(host string) *TLS {
host = strings.TrimPrefix(host, "http://")
host = strings.TrimPrefix(host, "https://")
var url string
host = strings.TrimPrefix(host, "http://")
host = strings.TrimPrefix(host, "https://")
if tc.inner != nil {
url = "https://" + host
} else {
Expand Down
43 changes: 43 additions & 0 deletions br/pkg/lightning/common/security_test.go
Expand Up @@ -70,6 +70,49 @@ func TestGetJSONSecure(t *testing.T) {
require.Equal(t, "/dddd", result.Path)
}

func TestWithHost(t *testing.T) {
mockTLSServer := httptest.NewTLSServer(http.HandlerFunc(respondPathHandler))
defer mockTLSServer.Close()
mockServer := httptest.NewServer(http.HandlerFunc(respondPathHandler))
defer mockServer.Close()

testCases := []struct {
expected string
host string
secure bool
}{
{
"https://127.0.0.1:2379",
"http://127.0.0.1:2379",
true,
},
{
"http://127.0.0.1:2379",
"https://127.0.0.1:2379",
false,
},
{
"http://127.0.0.1:2379/pd/api/v1/stores",
"127.0.0.1:2379/pd/api/v1/stores",
false,
},
{
"https://127.0.0.1:2379",
"127.0.0.1:2379",
true,
},
}

for _, testCase := range testCases {
server := mockServer
if testCase.secure {
server = mockTLSServer
}
tls := common.NewTLSFromMockServer(server)
require.Equal(t, testCase.expected, common.GetMockTLSUrl(tls.WithHost(testCase.host)))
}
}

func TestInvalidTLS(t *testing.T) {
tempDir := t.TempDir()
caPath := filepath.Join(tempDir, "ca.pem")
Expand Down
2 changes: 2 additions & 0 deletions br/pkg/lightning/importer/BUILD.bazel
Expand Up @@ -175,6 +175,8 @@ go_test(
"@com_github_stretchr_testify//require",
"@com_github_stretchr_testify//suite",
"@com_github_tikv_client_go_v2//config",
"@com_github_tikv_client_go_v2//testutils",
"@com_github_tikv_pd_client//:client",
"@com_github_xitongsys_parquet_go//writer",
"@com_github_xitongsys_parquet_go_source//buffer",
"@io_etcd_go_etcd_client_v3//:client",
Expand Down
12 changes: 2 additions & 10 deletions br/pkg/lightning/importer/checksum_helper.go
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/metric"
"github.com/pingcap/tidb/br/pkg/pdutil"
"github.com/pingcap/tidb/kv"
pd "github.com/tikv/pd/client"
"go.uber.org/zap"
)

Expand All @@ -37,21 +36,14 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) (
return nil, nil
}

pdAddr := rc.cfg.TiDB.PdAddr
pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, pdAddr)
pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, rc.pdCli.GetLeaderAddr())
if err != nil {
return nil, errors.Trace(err)
}

// for v4.0.0 or upper, we can use the gc ttl api
var manager local.ChecksumManager
if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL {
tlsOpt := rc.tls.ToPDSecurityOption()
pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt)
if err != nil {
return nil, errors.Trace(err)
}

backoffWeight, err := common.GetBackoffWeightFromDB(ctx, rc.db)
// only set backoff weight when it's smaller than default value
if err == nil && backoffWeight >= local.DefaultBackoffWeight {
Expand All @@ -66,7 +58,7 @@ func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) (
log.FromContext(ctx).Warn("get tidb_request_source_type failed", zap.Error(err), zap.String("tidb_request_source_type", explicitRequestSourceType))
return nil, errors.Trace(err)
}
manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight, rc.resourceGroupName, explicitRequestSourceType)
manager = local.NewTiKVChecksumManager(store.GetClient(), rc.pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight, rc.resourceGroupName, explicitRequestSourceType)
} else {
manager = local.NewTiDBChecksumExecutor(rc.db)
}
Expand Down
15 changes: 11 additions & 4 deletions br/pkg/lightning/importer/get_pre_info.go
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/dbterror"
"github.com/pingcap/tidb/util/mock"
pd "github.com/tikv/pd/client"
"go.uber.org/zap"
"golang.org/x/exp/maps"
)
Expand Down Expand Up @@ -123,12 +124,14 @@ type TargetInfoGetterImpl struct {
db *sql.DB
tls *common.TLS
backend backend.TargetInfoGetter
pdCli pd.Client
}

// NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object.
func NewTargetInfoGetterImpl(
cfg *config.Config,
targetDB *sql.DB,
pdCli pd.Client,
) (*TargetInfoGetterImpl, error) {
tls, err := cfg.ToTLS()
if err != nil {
Expand All @@ -139,7 +142,10 @@ func NewTargetInfoGetterImpl(
case config.BackendTiDB:
backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB)
case config.BackendLocal:
backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, cfg.TiDB.PdAddr)
if pdCli == nil {
return nil, common.ErrUnknown.GenWithStack("pd client is required when using local backend")
}
backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, pdCli)
default:
return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend)
}
Expand All @@ -148,6 +154,7 @@ func NewTargetInfoGetterImpl(
tls: tls,
db: targetDB,
backend: backendTargetInfoGetter,
pdCli: pdCli,
}, nil
}

Expand Down Expand Up @@ -229,7 +236,7 @@ func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Contex
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) {
result := new(pdtypes.ReplicationConfig)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdReplicate, &result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand All @@ -240,7 +247,7 @@ func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtyp
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) {
result := new(pdtypes.StoresInfo)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdStores, result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand All @@ -251,7 +258,7 @@ func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.Sto
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) {
result := new(pdtypes.RegionsInfo)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdEmptyRegions, &result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand Down
5 changes: 4 additions & 1 deletion br/pkg/lightning/importer/get_pre_info_test.go
Expand Up @@ -757,7 +757,10 @@ func TestGetPreInfoIsTableEmpty(t *testing.T) {
require.NoError(t, err)
lnConfig := config.NewConfig()
lnConfig.TikvImporter.Backend = config.BackendLocal
targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db)
_, err = NewTargetInfoGetterImpl(lnConfig, db, nil)
require.ErrorContains(t, err, "pd client is required when using local backend")
lnConfig.TikvImporter.Backend = config.BackendTiDB
targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db, nil)
require.NoError(t, err)
require.Equal(t, lnConfig, targetGetter.cfg)

Expand Down
19 changes: 15 additions & 4 deletions br/pkg/lightning/importer/import.go
Expand Up @@ -200,6 +200,7 @@ type Controller struct {
engineMgr backend.EngineManager
backend backend.Backend
db *sql.DB
pdCli pd.Client

alterTableLock sync.Mutex
sysVars map[string]string
Expand Down Expand Up @@ -333,6 +334,7 @@ func NewImportControllerWithPauser(

var encodingBuilder encode.EncodingBuilder
var backendObj backend.Backend
var pdCli pd.Client
switch cfg.TikvImporter.Backend {
case config.BackendTiDB:
encodingBuilder = tidb.NewEncodingBuilder()
Expand All @@ -348,9 +350,13 @@ func NewImportControllerWithPauser(
if maxOpenFiles < 0 {
maxOpenFiles = math.MaxInt32
}
pdCli, err = pd.NewClientWithContext(ctx, []string{cfg.TiDB.PdAddr}, tls.ToPDSecurityOption())
if err != nil {
return nil, errors.Trace(err)
}

if cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone {
if err := tikv.CheckTiKVVersion(ctx, tls, cfg.TiDB.PdAddr, minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil {
if err := tikv.CheckTiKVVersion(ctx, tls, pdCli.GetLeaderAddr(), minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil {
if !berrors.Is(err, berrors.ErrVersionMismatch) {
return nil, common.ErrCheckKVVersion.Wrap(err).GenWithStackByArgs()
}
Expand Down Expand Up @@ -419,7 +425,7 @@ func NewImportControllerWithPauser(

var wrapper backend.TargetInfoGetter
if cfg.TikvImporter.Backend == config.BackendLocal {
wrapper = local.NewTargetInfoGetter(tls, db, cfg.TiDB.PdAddr)
wrapper = local.NewTargetInfoGetter(tls, db, pdCli)
} else {
wrapper = tidb.NewTargetInfoGetter(db)
}
Expand All @@ -429,6 +435,7 @@ func NewImportControllerWithPauser(
db: db,
tls: tls,
backend: wrapper,
pdCli: pdCli,
}
preInfoGetter, err := NewPreImportInfoGetter(
cfg,
Expand Down Expand Up @@ -458,6 +465,7 @@ func NewImportControllerWithPauser(
pauser: p.Pauser,
engineMgr: backend.MakeEngineManager(backendObj),
backend: backendObj,
pdCli: pdCli,
db: db,
sysVars: common.DefaultImportantVariables,
tls: tls,
Expand All @@ -482,7 +490,7 @@ func NewImportControllerWithPauser(
preInfoGetter: preInfoGetter,
precheckItemBuilder: preCheckBuilder,
encBuilder: encodingBuilder,
tikvModeSwitcher: local.NewTiKVModeSwitcher(tls, cfg.TiDB.PdAddr, log.FromContext(ctx).Logger),
tikvModeSwitcher: local.NewTiKVModeSwitcher(tls, pdCli, log.FromContext(ctx).Logger),

keyspaceName: p.KeyspaceName,
resourceGroupName: p.ResourceGroupName,
Expand All @@ -495,6 +503,9 @@ func NewImportControllerWithPauser(
func (rc *Controller) Close() {
rc.backend.Close()
_ = rc.db.Close()
if rc.pdCli != nil {
rc.pdCli.Close()
}
}

// Run starts the restore task.
Expand Down Expand Up @@ -1870,7 +1881,7 @@ func (rc *Controller) fullCompact(ctx context.Context) error {
}

func (rc *Controller) doCompact(ctx context.Context, level int32) error {
tls := rc.tls.WithHost(rc.cfg.TiDB.PdAddr)
tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr())
return tikv.ForAllStores(
ctx,
tls,
Expand Down

0 comments on commit 203d839

Please sign in to comment.