Skip to content

Commit

Permalink
lightning: fix pd http request using old address (#45680) (#45728)
Browse files Browse the repository at this point in the history
close #43436
  • Loading branch information
ti-chi-bot committed Aug 16, 2023
1 parent 874c044 commit 415c36c
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 19 deletions.
12 changes: 6 additions & 6 deletions br/pkg/lightning/backend/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,15 @@ func (b *encodingBuilder) MakeEmptyRows() kv.Rows {
type targetInfoGetter struct {
tls *common.TLS
targetDBGlue glue.Glue
pdAddr string
pdCli pd.Client
}

// NewTargetInfoGetter creates an TargetInfoGetter with local backend implementation.
func NewTargetInfoGetter(tls *common.TLS, g glue.Glue, pdAddr string) backend.TargetInfoGetter {
func NewTargetInfoGetter(tls *common.TLS, g glue.Glue, pdCli pd.Client) backend.TargetInfoGetter {
return &targetInfoGetter{
tls: tls,
targetDBGlue: g,
pdAddr: pdAddr,
pdCli: pdCli,
}
}

Expand All @@ -264,10 +264,10 @@ func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *back
if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil {
return err
}
if err := tikv.CheckPDVersion(ctx, g.tls, g.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil {
if err := tikv.CheckPDVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinPDVersion, localMaxPDVersion); err != nil {
return err
}
if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil {
if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinTiKVVersion, localMaxTiKVVersion); err != nil {
return err
}

Expand Down Expand Up @@ -512,7 +512,7 @@ func NewLocalBackend(
writeLimiter: writeLimiter,
logger: log.FromContext(ctx),
encBuilder: NewEncodingBuilder(ctx),
targetInfoGetter: NewTargetInfoGetter(tls, g, cfg.TiDB.PdAddr),
targetInfoGetter: NewTargetInfoGetter(tls, g, pdCtl.GetPDClient()),
shouldCheckWriteStall: cfg.Cron.SwitchMode.Duration == 0,
}
if m, ok := metric.FromContext(ctx); ok {
Expand Down
8 changes: 8 additions & 0 deletions br/pkg/lightning/common/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/httputil"
Expand Down Expand Up @@ -86,8 +87,15 @@ func NewTLSFromMockServer(server *httptest.Server) *TLS {
}
}

// GetMockTLSUrl returns tls's host for mock test
func GetMockTLSUrl(tls *TLS) string {
return tls.url
}

// WithHost creates a new TLS instance with the host replaced.
func (tc *TLS) WithHost(host string) *TLS {
host = strings.TrimPrefix(host, "http://")
host = strings.TrimPrefix(host, "https://")
var url string
if tc.inner != nil {
url = "https://" + host
Expand Down
43 changes: 43 additions & 0 deletions br/pkg/lightning/common/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,49 @@ func TestGetJSONSecure(t *testing.T) {
require.Equal(t, "/dddd", result.Path)
}

func TestWithHost(t *testing.T) {
mockTLSServer := httptest.NewTLSServer(http.HandlerFunc(respondPathHandler))
defer mockTLSServer.Close()
mockServer := httptest.NewServer(http.HandlerFunc(respondPathHandler))
defer mockServer.Close()

testCases := []struct {
expected string
host string
secure bool
}{
{
"https://127.0.0.1:2379",
"http://127.0.0.1:2379",
true,
},
{
"http://127.0.0.1:2379",
"https://127.0.0.1:2379",
false,
},
{
"http://127.0.0.1:2379/pd/api/v1/stores",
"127.0.0.1:2379/pd/api/v1/stores",
false,
},
{
"https://127.0.0.1:2379",
"127.0.0.1:2379",
true,
},
}

for _, testCase := range testCases {
server := mockServer
if testCase.secure {
server = mockTLSServer
}
tls := common.NewTLSFromMockServer(server)
require.Equal(t, testCase.expected, common.GetMockTLSUrl(tls.WithHost(testCase.host)))
}
}

func TestInvalidTLS(t *testing.T) {
tempDir := t.TempDir()
caPath := filepath.Join(tempDir, "ca.pem")
Expand Down
Empty file.
1 change: 1 addition & 0 deletions br/pkg/lightning/restore/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ go_test(
"@com_github_stretchr_testify//suite",
"@com_github_tikv_client_go_v2//config",
"@com_github_tikv_client_go_v2//oracle",
"@com_github_tikv_client_go_v2//testutils",
"@com_github_tikv_pd_client//:client",
"@com_github_xitongsys_parquet_go//writer",
"@com_github_xitongsys_parquet_go_source//buffer",
Expand Down
15 changes: 11 additions & 4 deletions br/pkg/lightning/restore/get_pre_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/dbterror"
"github.com/pingcap/tidb/util/mock"
pd "github.com/tikv/pd/client"
"go.uber.org/zap"
"golang.org/x/exp/maps"
)
Expand Down Expand Up @@ -117,12 +118,14 @@ type TargetInfoGetterImpl struct {
targetDBGlue glue.Glue
tls *common.TLS
backend backend.TargetInfoGetter
pdCli pd.Client
}

// NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object.
func NewTargetInfoGetterImpl(
cfg *config.Config,
targetDB *sql.DB,
pdCli pd.Client,
) (*TargetInfoGetterImpl, error) {
targetDBGlue := glue.NewExternalTiDBGlue(targetDB, cfg.TiDB.SQLMode)
tls, err := cfg.ToTLS()
Expand All @@ -134,7 +137,10 @@ func NewTargetInfoGetterImpl(
case config.BackendTiDB:
backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB)
case config.BackendLocal:
backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDBGlue, cfg.TiDB.PdAddr)
if pdCli == nil {
return nil, common.ErrUnknown.GenWithStack("pd client is required when using local backend")
}
backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDBGlue, pdCli)
default:
return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend)
}
Expand All @@ -143,6 +149,7 @@ func NewTargetInfoGetterImpl(
targetDBGlue: targetDBGlue,
tls: tls,
backend: backendTargetInfoGetter,
pdCli: pdCli,
}, nil
}

Expand Down Expand Up @@ -231,7 +238,7 @@ func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Contex
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) {
result := new(pdtypes.ReplicationConfig)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdReplicate, &result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand All @@ -242,7 +249,7 @@ func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtyp
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) {
result := new(pdtypes.StoresInfo)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdStores, result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand All @@ -253,7 +260,7 @@ func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.Sto
// It uses the PD interface through TLS to get the information.
func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) {
result := new(pdtypes.RegionsInfo)
if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil {
if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdEmptyRegions, &result); err != nil {
return nil, errors.Trace(err)
}
return result, nil
Expand Down
5 changes: 4 additions & 1 deletion br/pkg/lightning/restore/get_pre_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,10 @@ func TestGetPreInfoIsTableEmpty(t *testing.T) {
require.NoError(t, err)
lnConfig := config.NewConfig()
lnConfig.TikvImporter.Backend = config.BackendLocal
targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db)
_, err = NewTargetInfoGetterImpl(lnConfig, db, nil)
require.ErrorContains(t, err, "pd client is required when using local backend")
lnConfig.TikvImporter.Backend = config.BackendTiDB
targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db, nil)
require.NoError(t, err)
require.Equal(t, lnConfig, targetGetter.cfg)

Expand Down
5 changes: 3 additions & 2 deletions br/pkg/lightning/restore/precheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/config"
"github.com/pingcap/tidb/br/pkg/lightning/mydump"
ropts "github.com/pingcap/tidb/br/pkg/lightning/restore/opts"
pd "github.com/tikv/pd/client"
)

type CheckItemID string
Expand Down Expand Up @@ -57,7 +58,7 @@ type PrecheckItemBuilder struct {
checkpointsDB checkpoints.DB
}

func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) {
func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, pdCli pd.Client, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) {
var gerr error
builderCfg := new(ropts.PrecheckItemBuilderConfig)
for _, o := range opts {
Expand All @@ -67,7 +68,7 @@ func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, o
if err != nil {
return nil, errors.Trace(err)
}
targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB)
targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB, pdCli)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
15 changes: 13 additions & 2 deletions br/pkg/lightning/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ type Controller struct {
pauser *common.Pauser
backend backend.Backend
tidbGlue glue.Glue
pdCli pd.Client

alterTableLock sync.Mutex
sysVars map[string]string
Expand Down Expand Up @@ -329,6 +330,7 @@ func NewRestoreControllerWithPauser(
}

var backend backend.Backend
var pdCli pd.Client
switch cfg.TikvImporter.Backend {
case config.BackendTiDB:
backend = tidb.NewTiDBBackend(ctx, db, cfg.TikvImporter.OnDuplicate, errorMgr)
Expand All @@ -343,9 +345,13 @@ func NewRestoreControllerWithPauser(
if maxOpenFiles < 0 {
maxOpenFiles = math.MaxInt32
}
pdCli, err = pd.NewClientWithContext(ctx, []string{cfg.TiDB.PdAddr}, tls.ToPDSecurityOption())
if err != nil {
return nil, errors.Trace(err)
}

if cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone {
if err := tikv.CheckTiKVVersion(ctx, tls, cfg.TiDB.PdAddr, minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil {
if err := tikv.CheckTiKVVersion(ctx, tls, pdCli.GetLeaderAddr(), minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil {
if berrors.Is(err, berrors.ErrVersionMismatch) {
log.FromContext(ctx).Warn("TiKV version doesn't support duplicate resolution. The resolution algorithm will fall back to 'none'", zap.Error(err))
cfg.TikvImporter.DuplicateResolution = config.DupeResAlgNone
Expand Down Expand Up @@ -392,6 +398,7 @@ func NewRestoreControllerWithPauser(
targetDBGlue: p.Glue,
tls: tls,
backend: backend,
pdCli: pdCli,
}
preInfoGetter, err := NewPreRestoreInfoGetter(
cfg,
Expand Down Expand Up @@ -420,6 +427,7 @@ func NewRestoreControllerWithPauser(
checksumWorks: worker.NewPool(ctx, cfg.TiDB.ChecksumTableConcurrency, "checksum"),
pauser: p.Pauser,
backend: backend,
pdCli: pdCli,
tidbGlue: p.Glue,
sysVars: defaultImportantVariables,
tls: tls,
Expand Down Expand Up @@ -448,6 +456,9 @@ func NewRestoreControllerWithPauser(
func (rc *Controller) Close() {
rc.backend.Close()
rc.tidbGlue.GetSQLExecutor().Close()
if rc.pdCli != nil {
rc.pdCli.Close()
}
}

func (rc *Controller) Run(ctx context.Context) error {
Expand Down Expand Up @@ -1860,7 +1871,7 @@ func (rc *Controller) fullCompact(ctx context.Context) error {
}

func (rc *Controller) doCompact(ctx context.Context, level int32) error {
tls := rc.tls.WithHost(rc.cfg.TiDB.PdAddr)
tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr())
return tikv.ForAllStores(
ctx,
tls,
Expand Down
29 changes: 25 additions & 4 deletions br/pkg/lightning/restore/table_restore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ import (
filter "github.com/pingcap/tidb/util/table-filter"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/tikv/client-go/v2/testutils"
pd "github.com/tikv/pd/client"
)

type tableRestoreSuiteBase struct {
Expand Down Expand Up @@ -1099,6 +1101,8 @@ func (s *tableRestoreSuite) TestCheckClusterResource() {
require.NoError(s.T(), err)
mockStore, err := storage.NewLocalStorage(dir)
require.NoError(s.T(), err)
_, _, pdClient, err := testutils.NewMockTiKV("", nil)
require.NoError(s.T(), err)
for _, ca := range cases {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var err error
Expand All @@ -1115,9 +1119,11 @@ func (s *tableRestoreSuite) TestCheckClusterResource() {

url := strings.TrimPrefix(server.URL, "https://")
cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}}
pdCli := &mockPDClient{Client: pdClient, leaderAddr: url}
targetInfoGetter := &TargetInfoGetterImpl{
cfg: cfg,
tls: tls,
cfg: cfg,
tls: tls,
pdCli: pdCli,
}
preInfoGetter := &PreRestoreInfoGetterImpl{
cfg: cfg,
Expand All @@ -1132,6 +1138,7 @@ func (s *tableRestoreSuite) TestCheckClusterResource() {
checkTemplate: template,
preInfoGetter: preInfoGetter,
precheckItemBuilder: theCheckBuilder,
pdCli: pdCli,
}
var sourceSize int64
err = rc.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error {
Expand Down Expand Up @@ -1168,6 +1175,15 @@ func (mockTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(ta
return err
}

type mockPDClient struct {
pd.Client
leaderAddr string
}

func (m *mockPDClient) GetLeaderAddr() string {
return m.leaderAddr
}

func (s *tableRestoreSuite) TestCheckClusterRegion() {
type testCase struct {
stores pdtypes.StoresInfo
Expand All @@ -1184,6 +1200,8 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() {
}
return regions
}
_, _, pdClient, err := testutils.NewMockTiKV("", nil)
require.NoError(s.T(), err)

testCases := []testCase{
{
Expand Down Expand Up @@ -1263,10 +1281,12 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() {

url := strings.TrimPrefix(server.URL, "https://")
cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}}
pdCli := &mockPDClient{Client: pdClient, leaderAddr: url}

targetInfoGetter := &TargetInfoGetterImpl{
cfg: cfg,
tls: tls,
cfg: cfg,
tls: tls,
pdCli: pdCli,
}
dbMetas := []*mydump.MDDatabaseMeta{}
preInfoGetter := &PreRestoreInfoGetterImpl{
Expand All @@ -1283,6 +1303,7 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() {
preInfoGetter: preInfoGetter,
dbInfos: make(map[string]*checkpoints.TidbDBInfo),
precheckItemBuilder: theCheckBuilder,
pdCli: pdCli,
}

preInfoGetter.dbInfosCache = rc.dbInfos
Expand Down

0 comments on commit 415c36c

Please sign in to comment.